Add `CANDLE_NVCC_CCBIN` support for `candle-kernels`, and eliminate warning. (#836)
This commit is contained in:
parent
3e94324012
commit
1c09164021
|
@ -12,6 +12,7 @@ use half::{bf16, f16};
|
|||
use std::sync::{Arc, Mutex};
|
||||
|
||||
const USE_IM2COL_CONV1D: bool = true;
|
||||
#[cfg(not(feature = "cudnn"))]
|
||||
const USE_IM2COL_CONV2D: bool = true;
|
||||
|
||||
/// cudarc related errors
|
||||
|
|
|
@ -164,6 +164,8 @@ mod cuda {
|
|||
|
||||
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
|
||||
|
||||
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
|
||||
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
|
||||
let children = kernel_paths
|
||||
.par_iter()
|
||||
.flat_map(|p| {
|
||||
|
@ -188,8 +190,13 @@ mod cuda {
|
|||
.args(["--output-directory", &out_dir])
|
||||
// Flash attention only
|
||||
// .arg("--expt-relaxed-constexpr")
|
||||
.args(&include_options)
|
||||
.arg(p);
|
||||
.args(&include_options);
|
||||
if let Ok(ccbin_path) = &ccbin_env {
|
||||
command
|
||||
.arg("-allow-unsupported-compiler")
|
||||
.args(["-ccbin", ccbin_path]);
|
||||
}
|
||||
command.arg(p);
|
||||
Some((p, command.spawn()
|
||||
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
|
||||
}})
|
||||
|
|
Loading…
Reference in New Issue