Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
This commit is contained in:
parent
d6db305829
commit
d9904a3baf
26
Cargo.toml
26
Cargo.toml
|
@ -20,7 +20,7 @@ exclude = [
|
|||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.8.4"
|
||||
version = "0.9.0-alpha.1"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
|
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
|||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.8.4" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.4.1"
|
||||
|
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
|
|||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
ug = "0.1.0"
|
||||
ug-cuda = "0.1.0"
|
||||
ug-metal = "0.1.0"
|
||||
ug = "0.2.0"
|
||||
ug-cuda = "0.2.0"
|
||||
ug-metal = "0.2.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "1.1.1", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
|
|
@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
|
|||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||
return Ok(cudnn.clone());
|
||||
}
|
||||
let c = Cudnn::new(dev.cuda_device());
|
||||
let c = Cudnn::new(dev.cuda_stream());
|
||||
if let Ok(c) = &c {
|
||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||
}
|
||||
|
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
|
|||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||
};
|
||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
||||
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||
unsafe {
|
||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||
alg,
|
||||
|
|
|
@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
|
|||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||
use half::{bf16, f16};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||
|
@ -24,10 +25,17 @@ impl DeviceId {
|
|||
struct CudaRng(cudarc::curand::CudaRng);
|
||||
unsafe impl Send for CudaRng {}
|
||||
|
||||
pub struct ModuleStore {
|
||||
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CudaDevice {
|
||||
id: DeviceId,
|
||||
device: Arc<cudarc::driver::CudaDevice>,
|
||||
context: Arc<cudarc::driver::CudaContext>,
|
||||
modules: Arc<std::sync::RwLock<ModuleStore>>,
|
||||
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
|
||||
stream: Arc<cudarc::driver::CudaStream>,
|
||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
curand: Arc<Mutex<CudaRng>>,
|
||||
}
|
||||
|
@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
|
|||
}
|
||||
|
||||
impl std::ops::Deref for CudaDevice {
|
||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||
type Target = Arc<cudarc::driver::CudaStream>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
&self.stream
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CudaFunc {
|
||||
func: CudaFunction,
|
||||
stream: Arc<cudarc::driver::CudaStream>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaFunc {
|
||||
type Target = CudaFunction;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.func
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaFunc {
|
||||
pub fn into_cuda_function(self) -> CudaFunction {
|
||||
self.func
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! builder_arg {
|
||||
($b:ident, $($arg:expr),*) => {
|
||||
$(
|
||||
let __arg = $arg;
|
||||
$b.arg(&__arg);
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl CudaFunc {
|
||||
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
|
||||
self.stream.launch_builder(&self.func)
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
||||
self.device.clone()
|
||||
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
|
||||
self.stream.clone()
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
|
@ -56,7 +99,7 @@ impl CudaDevice {
|
|||
&self,
|
||||
func_name: &'static str,
|
||||
kernel: ug::lang::ssa::Kernel,
|
||||
) -> Result<CudaFunction> {
|
||||
) -> Result<CudaFunc> {
|
||||
let mut buf = vec![];
|
||||
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||
let cuda_code = String::from_utf8(buf)?;
|
||||
|
@ -65,12 +108,12 @@ impl CudaDevice {
|
|||
..Default::default()
|
||||
};
|
||||
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
||||
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
|
||||
let func = match self.device.get_func("ug", func_name) {
|
||||
Some(func) => func,
|
||||
None => crate::bail!("unknown function ug::{func_name}"),
|
||||
};
|
||||
Ok(func)
|
||||
let module = self.context.load_module(ptx).w()?;
|
||||
let func = module.load_function(func_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn id(&self) -> DeviceId {
|
||||
|
@ -84,57 +127,84 @@ impl CudaDevice {
|
|||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
||||
let params = (&data, v as u8, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u8;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||
let params = (&data, v as u32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as u32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::I64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
||||
let params = (&data, v as i64, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as i64;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||
let params = (&data, bf16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = bf16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||
let params = (&data, f16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = f16::from_f64(v);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||
let params = (&data, v as f32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
let v = v as f32;
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||
let params = (&data, v, elem_count);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
|
@ -144,38 +214,69 @@ impl CudaDevice {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||
if !self.has_func(module_name, module_name) {
|
||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||
// done once per kernel name.
|
||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||
.map_err(|cuda| CudaError::Load {
|
||||
cuda,
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
.w()?;
|
||||
pub fn get_or_load_custom_func(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
module_name: &str,
|
||||
ptx: &str,
|
||||
) -> Result<CudaFunc> {
|
||||
let ms = self.custom_modules.read().unwrap();
|
||||
if let Some(mdl) = ms.get(module_name).as_ref() {
|
||||
let func = mdl.load_function(fn_name).w()?;
|
||||
return Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
});
|
||||
}
|
||||
self.get_func(module_name, module_name)
|
||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||
// able to only build the error value if needed.
|
||||
.ok_or(CudaError::MissingKernel {
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
.w()
|
||||
drop(ms);
|
||||
let mut ms = self.custom_modules.write().unwrap();
|
||||
let cuda_module = self.context.load_module(ptx.into()).w()?;
|
||||
ms.insert(module_name.to_string(), cuda_module.clone());
|
||||
let func = cuda_module.load_function(fn_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
|
||||
let ms = self.modules.read().unwrap();
|
||||
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
|
||||
let func = mdl.load_function(fn_name).w()?;
|
||||
return Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
});
|
||||
}
|
||||
drop(ms);
|
||||
let mut ms = self.modules.write().unwrap();
|
||||
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
|
||||
ms.mdls[mdl.index()] = Some(cuda_module.clone());
|
||||
let func = cuda_module.load_function(fn_name).w()?;
|
||||
Ok(CudaFunc {
|
||||
func,
|
||||
stream: self.stream.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||
let stream = context.new_stream().w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||
let module_store = ModuleStore {
|
||||
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||
};
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
context,
|
||||
stream,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
|
|||
type Storage = CudaStorage;
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
||||
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||
let stream = context.default_stream();
|
||||
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||
let module_store = ModuleStore {
|
||||
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||
};
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
context,
|
||||
stream,
|
||||
blas: Arc::new(blas),
|
||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
|
|||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||
// state will be identical and the same random numbers will be generated.
|
||||
let mut curand = self.curand.lock().unwrap();
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn location(&self) -> crate::DeviceLocation {
|
||||
crate::DeviceLocation::Cuda {
|
||||
gpu_id: self.device.ordinal(),
|
||||
gpu_id: self.context.ordinal(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
|
|||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let slice = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorageRef::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorageRef::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorageRef::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorageRef::F16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorageRef::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorageRef::F64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
|
@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
|
|||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
|
@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
|
|||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::I64(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.htod_copy(storage).w()?;
|
||||
let data = self.memcpy_stod(&storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
|
@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
|
|||
}
|
||||
|
||||
fn synchronize(&self) -> Result<()> {
|
||||
self.device.synchronize().map_err(crate::Error::wrap)?;
|
||||
self.stream.synchronize().map_err(crate::Error::wrap)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -396,7 +396,10 @@ impl UgIOp1 {
|
|||
{
|
||||
let device = device.as_cuda_device()?;
|
||||
let func = device.compile(name, kernel)?;
|
||||
Ok(Self { name, func })
|
||||
Ok(Self {
|
||||
name,
|
||||
func: func.into_cuda_function(),
|
||||
})
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
{
|
||||
|
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
|
|||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
||||
use crate::cuda_backend::WrapErr;
|
||||
use cudarc::driver::LaunchAsync;
|
||||
use cudarc::driver::PushKernelArg;
|
||||
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let stream = sto.device.cuda_stream();
|
||||
// TODO: support more dtypes.
|
||||
let sto = sto.as_cuda_slice::<f32>()?;
|
||||
let sto = match layout.contiguous_offsets() {
|
||||
None => crate::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => sto.slice(o1..o2),
|
||||
};
|
||||
let params = (&sto,);
|
||||
let (g, b) = if elem_count % 32 == 0 {
|
||||
(elem_count / 32, 32)
|
||||
} else {
|
||||
|
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
|
|||
block_dim: (b as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe { self.func.clone().launch(cfg, params) }.w()?;
|
||||
let mut builder = stream.launch_builder(&self.func);
|
||||
builder.arg(&sto);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use super::{GgmlDType, QStorage};
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||
use crate::{CudaDevice, CudaStorage, Result};
|
||||
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
|
||||
use half::f16;
|
||||
|
||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
||||
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PaddedCudaSlice {
|
||||
|
@ -50,19 +50,20 @@ fn quantize_q8_1(
|
|||
ky: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<()> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let kx = elem_count;
|
||||
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
||||
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
||||
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
||||
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, ky as u32, 1),
|
||||
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let params = (src, dst, kx as i32, kx_padded as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(src);
|
||||
builder.arg(dst);
|
||||
barg!(builder, kx as i32, kx_padded as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -72,8 +73,6 @@ fn dequantize_f32(
|
|||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||
|
@ -99,7 +98,7 @@ fn dequantize_f32(
|
|||
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
|
@ -110,15 +109,20 @@ fn dequantize_f32(
|
|||
};
|
||||
|
||||
if is_k {
|
||||
let params = (&data.inner, &dst);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
} else {
|
||||
let nb32 = match dtype {
|
||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||
_ => elem_count / 32,
|
||||
};
|
||||
let params = (&data.inner, &dst, nb32 as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, nb32 as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
}
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
@ -129,8 +133,6 @@ fn dequantize_f16(
|
|||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||
|
@ -156,7 +158,7 @@ fn dequantize_f16(
|
|||
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
|
@ -167,15 +169,20 @@ fn dequantize_f16(
|
|||
};
|
||||
|
||||
if is_k {
|
||||
let params = (&data.inner, &dst);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
} else {
|
||||
let nb32 = match dtype {
|
||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||
_ => elem_count / 32,
|
||||
};
|
||||
let params = (&data.inner, &dst, nb32 as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, nb32 as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
}
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
|
|||
nrows: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||
if data_elems < ncols * nrows {
|
||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||
|
@ -210,7 +215,7 @@ fn dequantize_mul_mat_vec(
|
|||
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
|
@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
|
|||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(y);
|
||||
builder.arg(&dst);
|
||||
barg!(builder, ncols as i32, nrows as i32);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
|
@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
|
|||
b_size: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||
if data_elems < ncols * nrows {
|
||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||
|
@ -266,7 +273,7 @@ fn mul_mat_vec_via_q8_1(
|
|||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let kernel_name = format!("{kernel_name}{b_size}");
|
||||
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||
let (nblocks, nwarps) = match b_size {
|
||||
|
@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
|
|||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let params = (
|
||||
&data.inner,
|
||||
&y_q8_1,
|
||||
&dst,
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&data.inner);
|
||||
builder.arg(&y_q8_1);
|
||||
builder.arg(&dst);
|
||||
barg!(
|
||||
builder,
|
||||
/* ncols_x */ ncols as i32,
|
||||
/* nrows_x */ nrows as i32,
|
||||
/* nrows_y */ ncols_padded as i32,
|
||||
/* nrows_dst */ nrows as i32,
|
||||
/* nrows_dst */ nrows as i32
|
||||
);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
|
@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
|
|||
y_cols: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||
if data_elems < x_rows * x_cols {
|
||||
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
||||
|
@ -338,7 +345,7 @@ fn mul_mat_via_q8_1(
|
|||
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (
|
||||
|
@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
|
|||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let params = (
|
||||
/* vx */ &data.inner,
|
||||
/* vy */ &y_q8_1,
|
||||
/* dst */ &dst,
|
||||
let mut builder = func.builder();
|
||||
builder.arg(/* vx */ &data.inner);
|
||||
builder.arg(/* vy */ &y_q8_1);
|
||||
builder.arg(/* dst */ &dst);
|
||||
barg!(
|
||||
builder,
|
||||
/* ncols_x */ x_cols as i32,
|
||||
/* nrows_x */ x_rows as i32,
|
||||
/* ncols_y */ y_cols as i32,
|
||||
/* nrows_y */ k_padded as i32,
|
||||
/* nrows_dst */ x_rows as i32,
|
||||
/* nrows_dst */ x_rows as i32
|
||||
);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
|
@ -416,7 +425,7 @@ impl QCudaStorage {
|
|||
|
||||
let buffer = self
|
||||
.device
|
||||
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
|
||||
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
|
||||
.w()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
let block_len = elem_count / self.dtype.block_size();
|
||||
|
@ -449,7 +458,7 @@ impl QCudaStorage {
|
|||
// Run the quantization on cpu.
|
||||
let src = match &src.slice {
|
||||
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
||||
self.device.dtoh_sync_copy(data).w()?
|
||||
self.device.memcpy_dtov(data).w()?
|
||||
}
|
||||
_ => crate::bail!("only f32 can be quantized"),
|
||||
};
|
||||
|
@ -462,7 +471,7 @@ impl QCudaStorage {
|
|||
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
|
||||
self.device
|
||||
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
||||
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
||||
.w()?;
|
||||
self.data = PaddedCudaSlice {
|
||||
inner,
|
||||
|
@ -599,7 +608,7 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|||
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
||||
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
|
||||
device
|
||||
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
|
||||
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
|
||||
.w()?;
|
||||
Ok(QStorage::Cuda(QCudaStorage {
|
||||
data: PaddedCudaSlice {
|
||||
|
@ -624,7 +633,7 @@ mod test {
|
|||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||
let y = dev.htod_sync_copy(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -634,7 +643,7 @@ mod test {
|
|||
let dev = CudaDevice::new(0)?;
|
||||
let ncols = 256;
|
||||
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||
let y = dev.htod_sync_copy(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||
|
@ -647,7 +656,7 @@ mod test {
|
|||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
assert_eq!(vs.len(), 1);
|
||||
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||
// Q8 means 1/256 precision.
|
||||
|
@ -662,7 +671,7 @@ mod test {
|
|||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
assert_eq!(vs.len(), 1);
|
||||
assert_eq!(vs[0], 5561851.0);
|
||||
Ok(())
|
||||
|
@ -673,7 +682,7 @@ mod test {
|
|||
let dev = CudaDevice::new(0)?;
|
||||
let ncols = 256;
|
||||
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||
let y = dev.htod_sync_copy(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_via_q8_1(
|
||||
|
@ -687,7 +696,7 @@ mod test {
|
|||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
|
||||
/*
|
||||
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||
|
@ -714,7 +723,7 @@ mod test {
|
|||
let dev = CudaDevice::new(0)?;
|
||||
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
||||
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
||||
let y = dev.htod_sync_copy(&vs).w()?;
|
||||
let y = dev.memcpy_stod(&vs).w()?;
|
||||
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||
let cuda_storage = mul_mat_via_q8_1(
|
||||
|
@ -728,7 +737,7 @@ mod test {
|
|||
&dev,
|
||||
)?;
|
||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
||||
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ impl ArgSort {
|
|||
mod cuda {
|
||||
use super::*;
|
||||
use crate::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
|
||||
};
|
||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
||||
use crate::{CudaDevice, WithDType};
|
||||
|
@ -69,6 +69,8 @@ mod cuda {
|
|||
layout: &crate::Layout,
|
||||
_wrap: W,
|
||||
) -> Result<S> {
|
||||
use cudarc::driver::PushKernelArg;
|
||||
|
||||
let slice = match layout.contiguous_offsets() {
|
||||
None => crate::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
|
@ -76,20 +78,24 @@ mod cuda {
|
|||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||
let func = if self.asc {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
||||
} else {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
|
||||
};
|
||||
let ncols = self.last_dim;
|
||||
let nrows = elem_count / ncols;
|
||||
let ncols_pad = next_power_of_2(ncols);
|
||||
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (1, nrows as u32, 1),
|
||||
block_dim: (ncols_pad as u32, 1, 1),
|
||||
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||
};
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let stream = dev.cuda_stream();
|
||||
let mut builder = stream.launch_builder(&func);
|
||||
let ncols = ncols as i32;
|
||||
let ncols_pad = ncols_pad as i32;
|
||||
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(S::U32(dst))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
|
|||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
|
||||
use candle::cuda_backend::WrapErr;
|
||||
let (d1, d2) = layout.shape().dims2()?;
|
||||
let d1 = d1 as u32;
|
||||
|
@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
|
|||
};
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||
let params = (&dst, &slice, self.eps, d1, d2);
|
||||
let func =
|
||||
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (d1, 1, 1),
|
||||
block_dim: (d2, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&dst);
|
||||
builder.arg(&slice);
|
||||
candle::builder_arg!(builder, self.eps, d1, d2);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, layout.shape().clone()))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.8.4"
|
||||
version = "0.9.0-alpha.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
|
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
|
|
@ -88,6 +88,7 @@ impl FlashAttn {
|
|||
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
|
||||
}
|
||||
|
||||
let stream = dev.cuda_stream();
|
||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||
if alibi_slopes.dtype() != DType::F32 {
|
||||
candle::bail!(
|
||||
|
@ -114,7 +115,9 @@ impl FlashAttn {
|
|||
|
||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||
|
||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
||||
// Dropping the guard here doesn't seem very safe.
|
||||
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
|
||||
ptr as *const core::ffi::c_void
|
||||
} else {
|
||||
std::ptr::null()
|
||||
};
|
||||
|
@ -161,17 +164,17 @@ impl FlashAttn {
|
|||
}
|
||||
|
||||
unsafe {
|
||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
||||
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
|
||||
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
|
||||
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
|
||||
let (q_ptr, _guard) = q.device_ptr(&stream);
|
||||
let (k_ptr, _guard) = k.device_ptr(&stream);
|
||||
let (v_ptr, _guard) = v.device_ptr(&stream);
|
||||
let (dst_ptr, _guard) = dst.device_ptr(&stream);
|
||||
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
|
||||
ffi::run_mha(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
dst_ptr,
|
||||
softmax_lse_ptr,
|
||||
q_ptr as *const core::ffi::c_void,
|
||||
k_ptr as *const core::ffi::c_void,
|
||||
v_ptr as *const core::ffi::c_void,
|
||||
dst_ptr as *const core::ffi::c_void,
|
||||
softmax_lse_ptr as *const core::ffi::c_void,
|
||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
||||
/* cu_seqlens_q_ptr */ std::ptr::null(),
|
||||
/* cu_seqlens_k_ptr */ std::ptr::null(),
|
||||
|
@ -550,6 +553,7 @@ impl FlashAttnVarLen {
|
|||
|
||||
let batch_size = nseqlens_q - 1;
|
||||
|
||||
let stream = dev.cuda_stream();
|
||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||
if alibi_slopes.dtype() != DType::F32 {
|
||||
candle::bail!(
|
||||
|
@ -576,7 +580,9 @@ impl FlashAttnVarLen {
|
|||
|
||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||
|
||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
||||
// Dropping the guard here doesn't seem very safe.
|
||||
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
|
||||
ptr as *const core::ffi::c_void
|
||||
} else {
|
||||
std::ptr::null()
|
||||
};
|
||||
|
@ -621,22 +627,22 @@ impl FlashAttnVarLen {
|
|||
}
|
||||
|
||||
unsafe {
|
||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
||||
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
|
||||
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
|
||||
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
|
||||
let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
|
||||
let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
|
||||
let (q_ptr, _guard) = q.device_ptr(&stream);
|
||||
let (k_ptr, _guard) = k.device_ptr(&stream);
|
||||
let (v_ptr, _guard) = v.device_ptr(&stream);
|
||||
let (dst_ptr, _guard) = dst.device_ptr(&stream);
|
||||
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
|
||||
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
|
||||
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
|
||||
ffi::run_mha(
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
dst_ptr,
|
||||
softmax_lse_ptr,
|
||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
||||
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
|
||||
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
|
||||
q_ptr as *const core::ffi::c_void,
|
||||
k_ptr as *const core::ffi::c_void,
|
||||
v_ptr as *const core::ffi::c_void,
|
||||
dst_ptr as *const core::ffi::c_void,
|
||||
softmax_lse_ptr as *const core::ffi::c_void,
|
||||
/* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
|
||||
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
|
||||
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
|
||||
/* q_batch_stride */ 0,
|
||||
/* k_batch_stride */ 0,
|
||||
/* v_batch_stride */ 0,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.8.4"
|
||||
version = "0.9.0-alpha.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
|
|
@ -7,5 +7,5 @@ fn main() {
|
|||
let builder = bindgen_cuda::Builder::default();
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write("src/lib.rs").unwrap();
|
||||
bindings.write("src/ptx.rs").unwrap();
|
||||
}
|
||||
|
|
|
@ -1,11 +1,78 @@
|
|||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
mod ptx;
|
||||
|
||||
#[repr(u32)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Id {
|
||||
Affine,
|
||||
Binary,
|
||||
Cast,
|
||||
Conv,
|
||||
Fill,
|
||||
Indexing,
|
||||
Quantized,
|
||||
Reduce,
|
||||
Sort,
|
||||
Ternary,
|
||||
Unary,
|
||||
}
|
||||
|
||||
pub const ALL_IDS: [Id; 11] = [
|
||||
Id::Affine,
|
||||
Id::Binary,
|
||||
Id::Cast,
|
||||
Id::Conv,
|
||||
Id::Fill,
|
||||
Id::Indexing,
|
||||
Id::Quantized,
|
||||
Id::Reduce,
|
||||
Id::Sort,
|
||||
Id::Ternary,
|
||||
Id::Unary,
|
||||
];
|
||||
|
||||
pub struct Module {
|
||||
index: usize,
|
||||
ptx: &'static str,
|
||||
}
|
||||
|
||||
impl Module {
|
||||
pub fn index(&self) -> usize {
|
||||
self.index
|
||||
}
|
||||
|
||||
pub fn ptx(&self) -> &'static str {
|
||||
self.ptx
|
||||
}
|
||||
}
|
||||
|
||||
const fn module_index(id: Id) -> usize {
|
||||
let mut i = 0;
|
||||
while i < ALL_IDS.len() {
|
||||
if ALL_IDS[i] as u32 == id as u32 {
|
||||
return i;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
panic!("id not found")
|
||||
}
|
||||
|
||||
macro_rules! mdl {
|
||||
($cst:ident, $id:ident) => {
|
||||
pub const $cst: Module = Module {
|
||||
index: module_index(Id::$id),
|
||||
ptx: ptx::$cst,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
mdl!(AFFINE, Affine);
|
||||
mdl!(BINARY, Binary);
|
||||
mdl!(CAST, Cast);
|
||||
mdl!(CONV, Conv);
|
||||
mdl!(FILL, Fill);
|
||||
mdl!(INDEXING, Indexing);
|
||||
mdl!(QUANTIZED, Quantized);
|
||||
mdl!(REDUCE, Reduce);
|
||||
mdl!(SORT, Sort);
|
||||
mdl!(TERNARY, Ternary);
|
||||
mdl!(UNARY, Unary);
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.8.4"
|
||||
version = "0.9.0-alpha.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
|
|
@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
|
|||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
|
||||
};
|
||||
use candle::cuda_backend::SlicePtrOrNull;
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||
|
@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
|
|||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
|
||||
let params = (el_count, dims.len(), &ds, src, &out);
|
||||
let mut builder = func.builder();
|
||||
candle::builder_arg!(builder, el_count, dims.len());
|
||||
ds.builder_arg(&mut builder);
|
||||
builder.arg(src);
|
||||
builder.arg(&out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||
block_dim: (1, 32, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &dst, n_cols as i32);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, n_cols as i32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
|
|||
l2: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
|
|||
block_dim: (block_size, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&dst,
|
||||
&alpha,
|
||||
n_cols as i32,
|
||||
block_size as i32,
|
||||
self.eps,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
builder.arg(&alpha);
|
||||
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
|
|||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
|
|||
block_dim: (block_size, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
||||
let func =
|
||||
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&dst,
|
||||
&alpha,
|
||||
&beta,
|
||||
n_cols as i32,
|
||||
block_size as i32,
|
||||
self.eps,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&dst);
|
||||
builder.arg(&alpha);
|
||||
builder.arg(&beta);
|
||||
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
|
@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src,
|
||||
&cos,
|
||||
&sin,
|
||||
&dst,
|
||||
(b * h) as u32,
|
||||
(t * d) as u32,
|
||||
d as u32,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
|
@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||
l3: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||
};
|
||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||
let el = b * h * t * d;
|
||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||
let params = (
|
||||
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
|
||||
);
|
||||
let mut builder = func.builder();
|
||||
builder.arg(&src);
|
||||
builder.arg(&cos);
|
||||
builder.arg(&sin);
|
||||
builder.arg(&dst);
|
||||
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.8.4"
|
||||
version = "0.9.0-alpha.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
|
@ -10,8 +10,8 @@ categories = ["science"]
|
|||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.8.4" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
|
Loading…
Reference in New Issue