Add `CANDLE_NVCC_CCBIN` support for `candle-kernels`, and eliminate warning. (#836)

This commit is contained in:
Charles Lew 2023-09-13 18:39:22 +08:00 committed by GitHub
parent 3e94324012
commit 1c09164021
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 2 deletions

View File

@ -12,6 +12,7 @@ use half::{bf16, f16};
use std::sync::{Arc, Mutex};
const USE_IM2COL_CONV1D: bool = true;
#[cfg(not(feature = "cudnn"))]
const USE_IM2COL_CONV2D: bool = true;
/// cudarc related errors

View File

@ -164,6 +164,8 @@ mod cuda {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
@ -188,8 +190,13 @@ mod cuda {
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options)
.arg(p);
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}})