Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.
* Adapt a couple more things.
* And a couple more fixes.
* More tweaks.
* And a couple more fixes.
* Bump the major version number.
* Proper module system for the cuda kernels.
* Proper ptx loading.
* Launch the sort kernel.
* Custom op.
* Start using the builder pattern.
* More builder.
* More builder.
* Get candle-core to compile.
* Get the tests to pass.
* Get candle-nn to work too.
* Support for custom cuda functions.
* cudnn fixes.
* Get flash attn to run.
* Switch the crate versions to be alpha.
* Bump the ug dependency.
This commit is contained in: parent d6db305829 · commit d9904a3baf
Cargo.toml | 26 lines changed
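Most of this diff is mechanical fallout from one API change: cudarc 0.14 drops the `LaunchAsync` trait and its tuple-of-params `launch`, and replaces it with a launch builder obtained from a `CudaStream`, with arguments pushed one at a time via `PushKernelArg`. A minimal sketch of the new calling convention, distilled from the patterns below (the function name and argument list are illustrative, not code from this commit):

```rust
use std::sync::Arc;

use cudarc::driver::{
    CudaFunction, CudaSlice, CudaStream, DriverError, LaunchConfig, PushKernelArg,
};

// cudarc 0.14 launch: push each argument onto a builder tied to a stream.
// Previously (0.13) this was `unsafe { func.launch(cfg, (&data, v, n)) }`.
fn launch_fill(
    stream: &Arc<CudaStream>,
    func: &CudaFunction,
    data: &CudaSlice<u32>,
    v: u32,
    n: usize,
    cfg: LaunchConfig,
) -> Result<(), DriverError> {
    let mut builder = stream.launch_builder(func);
    builder.arg(data); // device buffer
    builder.arg(&v); // scalars are borrowed and must outlive the builder
    builder.arg(&n);
    // Safety: argument types and order must match the kernel signature.
    unsafe { builder.launch(cfg) }?;
    Ok(())
}
```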
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
-candle-nn = { path = "./candle-nn", version = "0.8.4" }
+candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
@@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
-ug = "0.1.0"
+ug = "0.2.0"
-ug-cuda = "0.1.0"
+ug-cuda = "0.2.0"
-ug-metal = "0.1.0"
+ug-metal = "0.2.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "1.1.1", default-features = false }
 metal = { version = "0.27.0", features = ["mps"]}

@@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
     if let Some(cudnn) = cudnn.borrow().get(&device_id) {
         return Ok(cudnn.clone());
     }
-    let c = Cudnn::new(dev.cuda_device());
+    let c = Cudnn::new(dev.cuda_stream());
     if let Ok(c) = &c {
         cudnn.borrow_mut().insert(device_id, c.clone());
     }
@@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
         Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
     };
     let workspace_size = conv2d.get_workspace_size(alg)?;
-    let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
+    let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
     unsafe {
         conv2d.launch::<CudaSlice<u8>, _, _, _>(
             alg,
@@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
 use half::{bf16, f16};
+use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
 use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@@ -24,10 +25,17 @@ impl DeviceId {
 struct CudaRng(cudarc::curand::CudaRng);
 unsafe impl Send for CudaRng {}
 
+pub struct ModuleStore {
+    mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
+}
+
 #[derive(Clone)]
 pub struct CudaDevice {
     id: DeviceId,
-    device: Arc<cudarc::driver::CudaDevice>,
+    context: Arc<cudarc::driver::CudaContext>,
+    modules: Arc<std::sync::RwLock<ModuleStore>>,
+    custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
+    stream: Arc<cudarc::driver::CudaStream>,
     pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
     curand: Arc<Mutex<CudaRng>>,
 }
@@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
 }
 
 impl std::ops::Deref for CudaDevice {
-    type Target = Arc<cudarc::driver::CudaDevice>;
+    type Target = Arc<cudarc::driver::CudaStream>;
 
     fn deref(&self) -> &Self::Target {
-        &self.device
+        &self.stream
     }
 }
 
+pub struct CudaFunc {
+    func: CudaFunction,
+    stream: Arc<cudarc::driver::CudaStream>,
+}
+
+impl std::ops::Deref for CudaFunc {
+    type Target = CudaFunction;
+
+    fn deref(&self) -> &Self::Target {
+        &self.func
+    }
+}
+
+impl CudaFunc {
+    pub fn into_cuda_function(self) -> CudaFunction {
+        self.func
+    }
+}
+
+#[macro_export]
+macro_rules! builder_arg {
+    ($b:ident, $($arg:expr),*) => {
+        $(
+            let __arg = $arg;
+            $b.arg(&__arg);
+        )*
+    };
+}
+
+impl CudaFunc {
+    pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
+        self.stream.launch_builder(&self.func)
+    }
+}
+
 impl CudaDevice {
-    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
-        self.device.clone()
+    pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
+        self.stream.clone()
     }
 
     #[cfg(not(target_arch = "wasm32"))]
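For call sites, the new `CudaFunc` couples the loaded `CudaFunction` with the stream it should run on, so `func.builder()` is all that is needed to start a launch, and `builder_arg!` spares the caller the scalar-binding boilerplate (each value is bound to a local before being borrowed). A hedged usage sketch; the kernel name, config, and argument list are illustrative only:

```rust
// Inside candle-core, launching a module kernel with the new helpers.
let func = dev.get_or_load_func("fill_f32", &kernels::FILL)?;
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let mut builder = func.builder(); // LaunchArgs on the stream stored in CudaFunc
builder.arg(&data); // device buffer
crate::builder_arg!(builder, 1.0f32, elem_count); // binds scalars, then borrows them
unsafe { builder.launch(cfg) }.w()?;
```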
@@ -56,7 +99,7 @@ impl CudaDevice {
         &self,
         func_name: &'static str,
         kernel: ug::lang::ssa::Kernel,
-    ) -> Result<CudaFunction> {
+    ) -> Result<CudaFunc> {
         let mut buf = vec![];
         ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
         let cuda_code = String::from_utf8(buf)?;
@@ -65,12 +108,12 @@ impl CudaDevice {
             ..Default::default()
         };
         let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
-        self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
-        let func = match self.device.get_func("ug", func_name) {
-            Some(func) => func,
-            None => crate::bail!("unknown function ug::{func_name}"),
-        };
-        Ok(func)
+        let module = self.context.load_module(ptx).w()?;
+        let func = module.load_function(func_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
     }
 
     pub fn id(&self) -> DeviceId {
@@ -84,57 +127,84 @@ impl CudaDevice {
             DType::U8 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
-                let params = (&data, v as u8, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as u8;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U8(data)
             }
             DType::U32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
-                let params = (&data, v as u32, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as u32;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U32(data)
             }
             DType::I64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
-                let params = (&data, v as i64, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as i64;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::I64(data)
             }
             DType::BF16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
-                let params = (&data, bf16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = bf16::from_f64(v);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::BF16(data)
             }
             DType::F16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
-                let params = (&data, f16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = f16::from_f64(v);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F16(data)
             }
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
-                let params = (&data, v as f32, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as f32;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
-                let params = (&data, v, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F64(data)
             }
         };
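The recurring `let v = v as u8;` rebinding in these arms is load-bearing: `arg` stores a borrow of each argument until the launch, so a cast result needs a named local that outlives the builder; writing `builder.arg(&(v as u8))` would borrow a temporary that the compiler rejects. The `builder_arg!` macro above packages exactly this bind-then-borrow step.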
@@ -144,38 +214,69 @@ impl CudaDevice {
         })
     }
 
-    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
-        if !self.has_func(module_name, module_name) {
-            // Leaking the string here is a bit sad but we need a &'static str and this is only
-            // done once per kernel name.
-            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
-            self.load_ptx(ptx.into(), module_name, &[static_module_name])
-                .map_err(|cuda| CudaError::Load {
-                    cuda,
-                    module_name: module_name.to_string(),
-                })
-                .w()?;
-        }
-        self.get_func(module_name, module_name)
-            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
-            // able to only build the error value if needed.
-            .ok_or(CudaError::MissingKernel {
-                module_name: module_name.to_string(),
-            })
-            .w()
+    pub fn get_or_load_custom_func(
+        &self,
+        fn_name: &str,
+        module_name: &str,
+        ptx: &str,
+    ) -> Result<CudaFunc> {
+        let ms = self.custom_modules.read().unwrap();
+        if let Some(mdl) = ms.get(module_name).as_ref() {
+            let func = mdl.load_function(fn_name).w()?;
+            return Ok(CudaFunc {
+                func,
+                stream: self.stream.clone(),
+            });
+        }
+        drop(ms);
+        let mut ms = self.custom_modules.write().unwrap();
+        let cuda_module = self.context.load_module(ptx.into()).w()?;
+        ms.insert(module_name.to_string(), cuda_module.clone());
+        let func = cuda_module.load_function(fn_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
+    }
+
+    pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
+        let ms = self.modules.read().unwrap();
+        if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
+            let func = mdl.load_function(fn_name).w()?;
+            return Ok(CudaFunc {
+                func,
+                stream: self.stream.clone(),
+            });
+        }
+        drop(ms);
+        let mut ms = self.modules.write().unwrap();
+        let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
+        ms.mdls[mdl.index()] = Some(cuda_module.clone());
+        let func = cuda_module.load_function(fn_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
     }
 }
 
 impl CudaDevice {
     pub fn new_with_stream(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+        let stream = context.new_stream().w()?;
+        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+        let module_store = ModuleStore {
+            mdls: [const { None }; kernels::ALL_IDS.len()],
+        };
         Ok(Self {
             id: DeviceId::new(),
-            device,
+            context,
+            stream,
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
+            modules: Arc::new(std::sync::RwLock::new(module_store)),
+            custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
        })
    }
}
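Both loaders use the same lock discipline: check the cache under the `RwLock` read guard, and on a miss drop it and take the write guard before loading the module. Two threads can race past the read check and both call `load_module`; the later insert simply replaces the earlier one, which is harmless since the modules come from the same PTX, and the common hit path never contends on the write lock.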
@@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
     type Storage = CudaStorage;
 
     fn new(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+        let stream = context.default_stream();
+        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+        let module_store = ModuleStore {
+            mdls: [const { None }; kernels::ALL_IDS.len()],
+        };
         Ok(Self {
             id: DeviceId::new(),
-            device,
+            context,
+            stream,
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
+            modules: Arc::new(std::sync::RwLock::new(module_store)),
+            custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
         })
     }
@@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
         // We do not call set_seed but instead create a new curand object. This ensures that the
         // state will be identical and the same random numbers will be generated.
         let mut curand = self.curand.lock().unwrap();
-        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
+        curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
         Ok(())
     }
 
     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Cuda {
-            gpu_id: self.device.ordinal(),
+            gpu_id: self.context.ordinal(),
         }
     }
 
@@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
         let slice = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorageRef::U32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorageRef::I64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorageRef::BF16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorageRef::F16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorageRef::F32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorageRef::F64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
     }
 
     fn synchronize(&self) -> Result<()> {
-        self.device.synchronize().map_err(crate::Error::wrap)?;
+        self.stream.synchronize().map_err(crate::Error::wrap)?;
        Ok(())
    }
}
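For reference, the cudarc 0.14 memcpy renames that account for the rest of the churn in this file and in the quantized backend below (old 0.13 name → new name):

* `htod_sync_copy(slice)` → `memcpy_stod(slice)` (host slice to a new device buffer)
* `htod_copy(vec)` → `memcpy_stod(&vec)` (the owned-`Vec` variant is folded into the slice one)
* `htod_sync_copy_into(src, dst)` → `memcpy_htod(src, dst)` (copy into an existing device buffer)
* `dtoh_sync_copy(slice)` → `memcpy_dtov(slice)` (device slice to a host `Vec`)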
(File diff suppressed because it is too large.)
@@ -396,7 +396,10 @@ impl UgIOp1 {
         {
             let device = device.as_cuda_device()?;
             let func = device.compile(name, kernel)?;
-            Ok(Self { name, func })
+            Ok(Self {
+                name,
+                func: func.into_cuda_function(),
+            })
         }
         #[cfg(feature = "metal")]
         {
@@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
     #[cfg(feature = "cuda")]
     fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
         use crate::cuda_backend::WrapErr;
-        use cudarc::driver::LaunchAsync;
+        use cudarc::driver::PushKernelArg;
 
         let elem_count = layout.shape().elem_count();
+        let stream = sto.device.cuda_stream();
         // TODO: support more dtypes.
         let sto = sto.as_cuda_slice::<f32>()?;
         let sto = match layout.contiguous_offsets() {
             None => crate::bail!("input has to be contiguous"),
             Some((o1, o2)) => sto.slice(o1..o2),
         };
-        let params = (&sto,);
         let (g, b) = if elem_count % 32 == 0 {
             (elem_count / 32, 32)
         } else {
@@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
             block_dim: (b as u32, 1, 1),
             shared_mem_bytes: 0,
         };
-        unsafe { self.func.clone().launch(cfg, params) }.w()?;
+        let mut builder = stream.launch_builder(&self.func);
+        builder.arg(&sto);
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }
@@ -1,10 +1,10 @@
 use super::{GgmlDType, QStorage};
 use crate::quantized::k_quants::GgmlType;
 use crate::{backend::BackendDevice, cuda_backend::WrapErr};
-use crate::{CudaDevice, CudaStorage, Result};
+use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
 use half::f16;
 
-use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
+use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
 
 #[derive(Clone, Debug)]
 struct PaddedCudaSlice {
@@ -50,19 +50,20 @@ fn quantize_q8_1(
     ky: usize,
     dev: &CudaDevice,
 ) -> Result<()> {
-    use cudarc::driver::LaunchAsync;
 
     let kx = elem_count;
     let kx_padded = pad(kx, MATRIX_ROW_PADDING);
     let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
-    let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (num_blocks as u32, ky as u32, 1),
         block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
         shared_mem_bytes: 0,
     };
-    let params = (src, dst, kx as i32, kx_padded as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
+    let mut builder = func.builder();
+    builder.arg(src);
+    builder.arg(dst);
+    barg!(builder, kx as i32, kx_padded as i32);
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(())
 }
@@ -72,8 +73,6 @@ fn dequantize_f32(
     elem_count: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
         GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
@@ -99,7 +98,7 @@ fn dequantize_f32(
         GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
};
|
};
|
||||||
|
|
||||||
if is_k {
|
if is_k {
|
||||||
let params = (&data.inner, &dst);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
} else {
|
} else {
|
||||||
let nb32 = match dtype {
|
let nb32 = match dtype {
|
||||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
_ => elem_count / 32,
|
_ => elem_count / 32,
|
||||||
};
|
};
|
||||||
let params = (&data.inner, &dst, nb32 as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, nb32 as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
}
|
}
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
@@ -129,8 +133,6 @@ fn dequantize_f16(
     elem_count: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
         GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
@@ -156,7 +158,7 @@ fn dequantize_f16(
         GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@@ -167,15 +169,20 @@ fn dequantize_f16(
     };
 
     if is_k {
-        let params = (&data.inner, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        unsafe { builder.launch(cfg) }.w()?;
     } else {
         let nb32 = match dtype {
             GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
             _ => elem_count / 32,
         };
-        let params = (&data.inner, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        barg!(builder, nb32 as i32);
+        unsafe { builder.launch(cfg) }.w()?;
     }
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
@@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
     nrows: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@@ -210,7 +215,7 @@ fn dequantize_mul_mat_vec(
         GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
     let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
     let cfg = cudarc::driver::LaunchConfig {
@@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
         shared_mem_bytes: 0,
     };
 
-    let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
+    let mut builder = func.builder();
+    builder.arg(&data.inner);
+    builder.arg(y);
+    builder.arg(&dst);
+    barg!(builder, ncols as i32, nrows as i32);
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
@@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
     b_size: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@@ -266,7 +273,7 @@ fn mul_mat_vec_via_q8_1(
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let kernel_name = format!("{kernel_name}{b_size}");
-    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
     // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
     let (nblocks, nwarps) = match b_size {
@@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
         shared_mem_bytes: 0,
     };
 
-    let params = (
-        &data.inner,
-        &y_q8_1,
-        &dst,
+    let mut builder = func.builder();
+    builder.arg(&data.inner);
+    builder.arg(&y_q8_1);
+    builder.arg(&dst);
+    barg!(
+        builder,
         /* ncols_x */ ncols as i32,
         /* nrows_x */ nrows as i32,
         /* nrows_y */ ncols_padded as i32,
-        /* nrows_dst */ nrows as i32,
+        /* nrows_dst */ nrows as i32
     );
-    unsafe { func.launch(cfg, params) }.w()?;
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
@@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
     y_cols: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < x_rows * x_cols {
         crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
@@ -338,7 +345,7 @@ fn mul_mat_via_q8_1(
         GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (
@@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
         shared_mem_bytes: 0,
     };
 
-    let params = (
-        /* vx */ &data.inner,
-        /* vy */ &y_q8_1,
-        /* dst */ &dst,
+    let mut builder = func.builder();
+    builder.arg(/* vx */ &data.inner);
+    builder.arg(/* vy */ &y_q8_1);
+    builder.arg(/* dst */ &dst);
+    barg!(
+        builder,
         /* ncols_x */ x_cols as i32,
         /* nrows_x */ x_rows as i32,
         /* ncols_y */ y_cols as i32,
         /* nrows_y */ k_padded as i32,
-        /* nrows_dst */ x_rows as i32,
+        /* nrows_dst */ x_rows as i32
     );
-    unsafe { func.launch(cfg, params) }.w()?;
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
@@ -416,7 +425,7 @@ impl QCudaStorage {
 
         let buffer = self
             .device
-            .dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
+            .memcpy_dtov(&self.data.inner.slice(..self.data.len))
             .w()?;
         let mut out = vec![0.0; elem_count];
         let block_len = elem_count / self.dtype.block_size();
@@ -449,7 +458,7 @@ impl QCudaStorage {
         // Run the quantization on cpu.
         let src = match &src.slice {
             crate::cuda_backend::CudaStorageSlice::F32(data) => {
-                self.device.dtoh_sync_copy(data).w()?
+                self.device.memcpy_dtov(data).w()?
             }
             _ => crate::bail!("only f32 can be quantized"),
         };
@@ -462,7 +471,7 @@ impl QCudaStorage {
             data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
         let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
         self.device
-            .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
             .w()?;
         self.data = PaddedCudaSlice {
             inner,
@@ -599,7 +608,7 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
     let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
     let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
     device
-        .htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
+        .memcpy_htod(data, &mut inner.slice_mut(..data.len()))
         .w()?;
     Ok(QStorage::Cuda(QCudaStorage {
         data: PaddedCudaSlice {
@@ -624,7 +633,7 @@ mod test {
             el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
         let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
@@ -634,7 +643,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_vec_via_q8_1(
@@ -647,7 +656,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         assert_eq!(vs.len(), 1);
         // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
         // Q8 means 1/256 precision.
@@ -662,7 +671,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         assert_eq!(vs.len(), 1);
         assert_eq!(vs[0], 5561851.0);
         Ok(())
@@ -673,7 +682,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -687,7 +696,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
 
         /*
         x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -714,7 +723,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let (x_rows, ncols, y_cols) = (4, 16, 2048);
         let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -728,7 +737,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
        Ok(())
    }
}

@@ -56,7 +56,7 @@ impl ArgSort {
 mod cuda {
     use super::*;
     use crate::cuda_backend::cudarc::driver::{
-        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
     };
     use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
     use crate::{CudaDevice, WithDType};
@@ -69,6 +69,8 @@ mod cuda {
         layout: &crate::Layout,
         _wrap: W,
     ) -> Result<S> {
+        use cudarc::driver::PushKernelArg;
+
         let slice = match layout.contiguous_offsets() {
             None => crate::bail!("input has to be contiguous"),
             Some((o1, o2)) => src.slice(o1..o2),
@@ -76,20 +78,24 @@ mod cuda {
         let elem_count = layout.shape().elem_count();
         let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
         let func = if self.asc {
-            dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
+            dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
         } else {
-            dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
+            dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
         };
         let ncols = self.last_dim;
         let nrows = elem_count / ncols;
         let ncols_pad = next_power_of_2(ncols);
-        let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
         let cfg = LaunchConfig {
             grid_dim: (1, nrows as u32, 1),
             block_dim: (ncols_pad as u32, 1, 1),
             shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
         };
-        unsafe { func.launch(cfg, params) }.w()?;
+        let stream = dev.cuda_stream();
+        let mut builder = stream.launch_builder(&func);
+        let ncols = ncols as i32;
+        let ncols_pad = ncols_pad as i32;
+        builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
+        unsafe { builder.launch(cfg) }.w()?;
        Ok(S::U32(dst))
    }
}
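Note the chained form `builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad)` in this argsort port: `arg` hands the builder back, so once every scalar has its own binding the pushes can be written in a single expression.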
@@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
         layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         use candle::backend::BackendStorage;
-        use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
+        use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
         use candle::cuda_backend::WrapErr;
         let (d1, d2) = layout.shape().dims2()?;
         let d1 = d1 as u32;
@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
|
||||||
};
|
};
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
|
let func =
|
||||||
let params = (&dst, &slice, self.eps, d1, d2);
|
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||||
let cfg = LaunchConfig {
|
let cfg = LaunchConfig {
|
||||||
grid_dim: (d1, 1, 1),
|
grid_dim: (d1, 1, 1),
|
||||||
block_dim: (d2, 1, 1),
|
block_dim: (d2, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&dst);
|
||||||
|
builder.arg(&slice);
|
||||||
|
candle::builder_arg!(builder, self.eps, d1, d2);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
|
|
||||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||||
Ok((dst, layout.shape().clone()))
|
Ok((dst, layout.shape().clone()))
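
Two API changes meet in this hunk: user-supplied PTX is now registered under an explicit module name (`"mymodule"` here is an arbitrary caller-chosen key, not a fixed API string), and scalar kernel arguments go through the new `candle::builder_arg!` macro. A condensed sketch of the resulting launch path, reusing the names from the hunk:

```rust
// Load a user kernel: the extra "mymodule" key namespaces this PTX in the
// device's module cache, separate from candle's built-in kernel modules.
let func =
    dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let mut builder = func.builder();
builder.arg(&dst);   // device buffers are pushed by reference...
builder.arg(&slice);
// ...while scalars (an f32 and two u32s here) go through the macro, which
// binds each value to a local so it stays alive for the launch.
candle::builder_arg!(builder, self.eps, d1, d2);
unsafe { builder.launch(cfg) }.w()?;
```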
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]

@@ -88,6 +88,7 @@ impl FlashAttn {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }

+        let stream = dev.cuda_stream();
         let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
             if alibi_slopes.dtype() != DType::F32 {
                 candle::bail!(
@@ -114,7 +115,9 @@ impl FlashAttn {

             let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

-            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+            // Dropping the guard here doesn't seem very safe.
+            let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
+            ptr as *const core::ffi::c_void
         } else {
             std::ptr::null()
         };
@@ -161,17 +164,17 @@ impl FlashAttn {
         }

         unsafe {
-            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
-            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
-            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
-            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
-            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
+            let (q_ptr, _guard) = q.device_ptr(&stream);
+            let (k_ptr, _guard) = k.device_ptr(&stream);
+            let (v_ptr, _guard) = v.device_ptr(&stream);
+            let (dst_ptr, _guard) = dst.device_ptr(&stream);
+            let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
             ffi::run_mha(
-                q_ptr,
-                k_ptr,
-                v_ptr,
-                dst_ptr,
-                softmax_lse_ptr,
+                q_ptr as *const core::ffi::c_void,
+                k_ptr as *const core::ffi::c_void,
+                v_ptr as *const core::ffi::c_void,
+                dst_ptr as *const core::ffi::c_void,
+                softmax_lse_ptr as *const core::ffi::c_void,
                 /* alibi_slopes_ptr */ alibi_slopes_ptr,
                 /* cu_seqlens_q_ptr */ std::ptr::null(),
                 /* cu_seqlens_k_ptr */ std::ptr::null(),
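
The pointer-extraction change above deserves a note: in cudarc 0.14, `device_ptr` takes the stream and returns a `(pointer, guard)` pair, and the guard is what keeps the pointer valid. The in-diff comment "Dropping the guard here doesn't seem very safe" flags the one spot where the guard goes out of scope before the raw pointer is consumed. A sketch of the pattern around the FFI call (remaining `run_mha` arguments elided, as in the hunk):

```rust
// 0.13: a bare raw pointer, cast inline.
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;

// 0.14: the pointer is tied to a stream and protected by a guard. Binding
// the guard (even as `_guard`) keeps it alive to the end of the enclosing
// scope, which must cover the unsafe FFI call that uses the raw pointer.
let stream = dev.cuda_stream();
let (q_ptr, _guard) = q.device_ptr(&stream);
unsafe {
    ffi::run_mha(
        q_ptr as *const core::ffi::c_void,
        /* ...remaining arguments exactly as in the hunk above... */
    );
}
```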
@@ -550,6 +553,7 @@ impl FlashAttnVarLen {

         let batch_size = nseqlens_q - 1;

+        let stream = dev.cuda_stream();
         let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
             if alibi_slopes.dtype() != DType::F32 {
                 candle::bail!(
@@ -576,7 +580,9 @@ impl FlashAttnVarLen {

             let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

-            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+            // Dropping the guard here doesn't seem very safe.
+            let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
+            ptr as *const core::ffi::c_void
         } else {
             std::ptr::null()
         };
@@ -621,22 +627,22 @@ impl FlashAttnVarLen {
         }

         unsafe {
-            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
-            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
-            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
-            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
-            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
-            let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
-            let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
+            let (q_ptr, _guard) = q.device_ptr(&stream);
+            let (k_ptr, _guard) = k.device_ptr(&stream);
+            let (v_ptr, _guard) = v.device_ptr(&stream);
+            let (dst_ptr, _guard) = dst.device_ptr(&stream);
+            let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
+            let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
+            let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
             ffi::run_mha(
-                q_ptr,
-                k_ptr,
-                v_ptr,
-                dst_ptr,
-                softmax_lse_ptr,
-                /* alibi_slopes_ptr */ alibi_slopes_ptr,
-                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
-                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
+                q_ptr as *const core::ffi::c_void,
+                k_ptr as *const core::ffi::c_void,
+                v_ptr as *const core::ffi::c_void,
+                dst_ptr as *const core::ffi::c_void,
+                softmax_lse_ptr as *const core::ffi::c_void,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
+                /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
+                /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
                 /* q_batch_stride */ 0,
                 /* k_batch_stride */ 0,
                 /* v_batch_stride */ 0,
@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "CUDA kernels for Candle"
@@ -7,5 +7,5 @@ fn main() {
     let builder = bindgen_cuda::Builder::default();
     println!("cargo:info={builder:?}");
     let bindings = builder.build_ptx().unwrap();
-    bindings.write("src/lib.rs").unwrap();
+    bindings.write("src/ptx.rs").unwrap();
 }
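
Worth spelling out why this one-line change matters: bindgen_cuda compiles each kernel to PTX and emits one `pub const <NAME>: &str = include_str!(...)` item per file. Writing those into src/ptx.rs instead of overwriting src/lib.rs leaves lib.rs free to be hand-maintained, wrapping each raw PTX string in the new `Module` type (next hunk). The whole build script after the change, annotated:

```rust
fn main() {
    // Collects the .cu kernel sources and the nvcc configuration.
    let builder = bindgen_cuda::Builder::default();
    println!("cargo:info={builder:?}");
    // Compiles every kernel to PTX under OUT_DIR.
    let bindings = builder.build_ptx().unwrap();
    // Emits the `pub const <NAME>: &str = include_str!(...)` items; these now
    // land in a private module instead of the crate root.
    bindings.write("src/ptx.rs").unwrap();
}
```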
@@ -1,11 +1,78 @@
-pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
-pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
-pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
-pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
-pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
-pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
-pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
-pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
-pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
-pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
-pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
+mod ptx;
+
+#[repr(u32)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Id {
+    Affine,
+    Binary,
+    Cast,
+    Conv,
+    Fill,
+    Indexing,
+    Quantized,
+    Reduce,
+    Sort,
+    Ternary,
+    Unary,
+}
+
+pub const ALL_IDS: [Id; 11] = [
+    Id::Affine,
+    Id::Binary,
+    Id::Cast,
+    Id::Conv,
+    Id::Fill,
+    Id::Indexing,
+    Id::Quantized,
+    Id::Reduce,
+    Id::Sort,
+    Id::Ternary,
+    Id::Unary,
+];
+
+pub struct Module {
+    index: usize,
+    ptx: &'static str,
+}
+
+impl Module {
+    pub fn index(&self) -> usize {
+        self.index
+    }
+
+    pub fn ptx(&self) -> &'static str {
+        self.ptx
+    }
+}
+
+const fn module_index(id: Id) -> usize {
+    let mut i = 0;
+    while i < ALL_IDS.len() {
+        if ALL_IDS[i] as u32 == id as u32 {
+            return i;
+        }
+        i += 1;
+    }
+    panic!("id not found")
+}
+
+macro_rules! mdl {
+    ($cst:ident, $id:ident) => {
+        pub const $cst: Module = Module {
+            index: module_index(Id::$id),
+            ptx: ptx::$cst,
+        };
+    };
+}
+
+mdl!(AFFINE, Affine);
+mdl!(BINARY, Binary);
+mdl!(CAST, Cast);
+mdl!(CONV, Conv);
+mdl!(FILL, Fill);
+mdl!(INDEXING, Indexing);
+mdl!(QUANTIZED, Quantized);
+mdl!(REDUCE, Reduce);
+mdl!(SORT, Sort);
+mdl!(TERNARY, Ternary);
+mdl!(UNARY, Unary);
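
The design here: `Id` gives every kernel family a stable identity, `module_index` is a `const fn` linear scan over `ALL_IDS` (fully evaluated at compile time, so `Module::index` costs nothing at runtime), and the `mdl!` macro keeps the public constant names (`SORT`, `REDUCE`, ...) identical to the old string constants while changing their type from `&str` to `Module`. A caller-side sketch of what the dense index enables; the vector-based cache mentioned in the comment is an assumption about how a backend might use it, not code from this PR:

```rust
// A per-device module cache can now be a dense Vec indexed by
// `Module::index()` instead of a map keyed by name or PTX contents.
let module = &candle_kernels::SORT;
let idx: usize = module.index(); // compile-time-derived position in ALL_IDS
let ptx: &'static str = module.ptx(); // embedded PTX, loaded on a cache miss
assert!(idx < candle_kernels::ALL_IDS.len());
```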
@@ -0,0 +1,11 @@
+pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
+pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
+pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
+pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
+pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
+pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
+pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
+pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
+pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
+pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
+pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "Metal kernels for Candle"
@@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
@@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
+                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                // SAFETY: Set later by running the kernel.
                let out = unsafe { dev.alloc::<T>(el_count) }.w()?;

-                let params = (el_count, dims.len(), &ds, src, &out);
+                let mut builder = func.builder();
+                candle::builder_arg!(builder, el_count, dims.len());
+                ds.builder_arg(&mut builder);
+                builder.arg(src);
+                builder.arg(&out);
                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }.w()?;
+                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }
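
This hunk shows the two candle-side helpers that make the builder migration mechanical across candle-nn: `candle::builder_arg!` for scalar arguments and `SlicePtrOrNull::builder_arg`, which pushes either the layout's dims/strides slice or a null pointer for contiguous inputs. The same launch, annotated as a sketch:

```rust
let mut builder = func.builder();
// Scalars first: the macro binds `el_count` and `dims.len()` to locals so
// the references pushed into the builder stay valid through the launch.
candle::builder_arg!(builder, el_count, dims.len());
// Layout info: a device slice of dims/strides for strided inputs, or a
// null pointer when the input is contiguous.
ds.builder_arg(&mut builder);
builder.arg(src);  // input
builder.arg(&out); // output, filled by the kernel
// SAFETY: argument count, order and types must match the `usigmoid` kernel.
unsafe { builder.launch(cfg) }.w()?;
```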
@@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};
@@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
                block_dim: (1, 32, 1),
                shared_mem_bytes: 0,
            };
-            let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
+            let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-            let params = (&src, &dst, n_cols as i32);
+            let mut builder = func.builder();
+            builder.arg(&src);
+            builder.arg(&dst);
+            candle::builder_arg!(builder, n_cols as i32);
            // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
            Ok(dst)
        }
    }
|
||||||
|
@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
|
||||||
l2: &Layout,
|
l2: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
|
@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
|
||||||
block_dim: (block_size, 1, 1),
|
block_dim: (block_size, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
&alpha,
|
builder.arg(&alpha);
|
||||||
n_cols as i32,
|
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||||
block_size as i32,
|
|
||||||
self.eps,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
@@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};
@@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
                block_dim: (block_size, 1, 1),
                shared_mem_bytes: 0,
            };
-            let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
+            let func =
+                dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-            let params = (
-                &src,
-                &dst,
-                &alpha,
-                &beta,
-                n_cols as i32,
-                block_size as i32,
-                self.eps,
-            );
+            let mut builder = func.builder();
+            builder.arg(&src);
+            builder.arg(&dst);
+            builder.arg(&alpha);
+            builder.arg(&beta);
+            candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
            // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
            Ok(dst)
        }
    }
@@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};
@@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
            let (b, h, t, d) = l_src.shape().dims4()?;
            let el = b * h * t * d;
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
-            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
+            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-            let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
+            let mut builder = func.builder();
+            builder.arg(&src);
+            builder.arg(&cos);
+            builder.arg(&sin);
+            builder.arg(&dst);
+            candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
            // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
            Ok(dst)
        }
@@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};
@@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
            let (b, h, t, d) = l_src.shape().dims4()?;
            let el = b * h * t * d;
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
-            let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
+            let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-            let params = (
-                &src,
-                &cos,
-                &sin,
-                &dst,
-                (b * h) as u32,
-                (t * d) as u32,
-                d as u32,
-            );
+            let mut builder = func.builder();
+            builder.arg(&src);
+            builder.arg(&cos);
+            builder.arg(&sin);
+            builder.arg(&dst);
+            candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
            // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
            Ok(dst)
        }
@@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
-            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};
@@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
            let (b, t, h, d) = l_src.shape().dims4()?;
            let el = b * h * t * d;
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
-            let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
+            let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
            // SAFETY: Set later by running the kernel.
            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-            let params = (
-                &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
-            );
+            let mut builder = func.builder();
+            builder.arg(&src);
+            builder.arg(&cos);
+            builder.arg(&sin);
+            builder.arg(&dst);
+            candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
            // SAFETY: ffi.
-            unsafe { func.launch(cfg, params) }.w()?;
+            unsafe { builder.launch(cfg) }.w()?;
            Ok(dst)
        }
@@ -1,6 +1,6 @@
 [package]
 name = "candle-onnx"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"

 [dependencies]
-candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
-candle-nn = { path = "../candle-nn", version = "0.8.4" }
+candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
+candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
 prost = "0.12.1"

 [build-dependencies]