1972 lines
72 KiB
Rust
1972 lines
72 KiB
Rust
use crate::backend::{BackendDevice, BackendStorage};
|
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
|
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
|
|
pub use candle_kernels as kernels;
|
|
pub use cudarc;
|
|
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
|
use cudarc::driver::{
|
|
CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
|
};
|
|
use half::{bf16, f16};
|
|
|
|
#[cfg(feature = "cudnn")]
|
|
pub mod cudnn;
|
|
mod device;
|
|
mod error;
|
|
mod utils;
|
|
pub use device::{CudaDevice, DeviceId};
|
|
pub use error::{CudaError, WrapErr};
|
|
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
|
|
|
pub enum SlicePtrOrNull<T> {
|
|
Ptr(CudaSlice<T>),
|
|
Null,
|
|
}
|
|
|
|
unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
|
|
fn as_kernel_param(&self) -> *mut std::ffi::c_void {
|
|
match self {
|
|
SlicePtrOrNull::Ptr(slice) => slice.as_kernel_param(),
|
|
SlicePtrOrNull::Null => 0usize.as_kernel_param(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl SlicePtrOrNull<usize> {
|
|
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
|
let ds = if l.is_contiguous() {
|
|
SlicePtrOrNull::Null
|
|
} else {
|
|
SlicePtrOrNull::Ptr(dev.htod_copy([l.dims(), l.stride()].concat()).w()?)
|
|
};
|
|
Ok(ds)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum CudaStorageSlice {
|
|
U8(CudaSlice<u8>),
|
|
U32(CudaSlice<u32>),
|
|
I64(CudaSlice<i64>),
|
|
BF16(CudaSlice<bf16>),
|
|
F16(CudaSlice<f16>),
|
|
F32(CudaSlice<f32>),
|
|
F64(CudaSlice<f64>),
|
|
}
|
|
|
|
struct Clone;
|
|
impl Map1 for Clone {
|
|
fn f<T: DeviceRepr>(
|
|
&self,
|
|
s: &CudaSlice<T>,
|
|
_: &CudaDevice,
|
|
_: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
s.try_clone().w()
|
|
}
|
|
}
|
|
|
|
pub fn kernel_name<T: WithDType>(root: &str) -> String {
|
|
let dtype = T::DTYPE.as_str();
|
|
format!("{root}_{dtype}")
|
|
}
|
|
|
|
struct Affine(f64, f64);
|
|
impl Map1 for Affine {
|
|
fn f<T: DeviceRepr + WithDType>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), kernels::AFFINE)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
|
let params = (
|
|
el,
|
|
dims.len(),
|
|
&ds,
|
|
src,
|
|
&out,
|
|
T::from_f64(self.0),
|
|
T::from_f64(self.1),
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct Elu(f64);
|
|
impl Map1 for Elu {
|
|
fn f<T: DeviceRepr + WithDType>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
/// Parameters of the 1D im2col transform (`im2col1d` kernel).
struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    /// Output length for an input of length `l`:
    /// `(l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1`.
    fn l_out(&self, l: usize) -> usize {
        let padded_len = l + 2 * self.padding;
        let kernel_span = self.dilation * (self.l_k - 1) + 1;
        (padded_len - kernel_span) / self.stride + 1
    }
}
|
|
|
|
impl Map1 for Im2Col1D {
|
|
fn f<T: DeviceRepr + WithDType>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let l_out = self.l_out(dims[2]);
|
|
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let params = (
|
|
dst_el,
|
|
l_out,
|
|
self.l_k,
|
|
self.stride,
|
|
self.padding,
|
|
self.dilation,
|
|
&ds,
|
|
src,
|
|
&dst,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
/// Parameters of the 2D im2col transform (`im2col` kernel). Stride, dilation and
/// padding are shared by both spatial dims.
struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    /// Output `(height, width)` for an `h x w` input. Each side uses
    /// `(len + 2 * padding - dilation * (k - 1) - 1) / stride + 1`.
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let span = |k: usize| self.dilation * (k - 1) + 1;
        let side = |len: usize, k: usize| (len + 2 * self.padding - span(k)) / self.stride + 1;
        (side(h, self.h_k), side(w, self.w_k))
    }
}
|
|
|
|
impl Map1 for Im2Col {
|
|
fn f<T: DeviceRepr + WithDType>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
|
|
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let params = (
|
|
dst_el,
|
|
h_out,
|
|
w_out,
|
|
self.h_k,
|
|
self.w_k,
|
|
self.stride,
|
|
self.padding,
|
|
self.dilation,
|
|
&ds,
|
|
src,
|
|
&dst,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(dst)
|
|
}
|
|
}
|
|
|
|
struct Powf(f64);
|
|
impl Map1 for Powf {
|
|
fn f<T: DeviceRepr + WithDType>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct Sum<'a>(&'a [usize]);
|
|
impl<'a> Map1 for Sum<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let src_dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let mut dst_el = el;
|
|
for &sum_dim in self.0.iter() {
|
|
dst_el /= src_dims[sum_dim];
|
|
}
|
|
let mut sum_dims = self.0.to_vec();
|
|
// Sort the sum_dims as they have to be processed from left to right when converting the
|
|
// indexes.
|
|
sum_dims.sort();
|
|
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
|
let sum_dims_s: Vec<usize> = sum_dims
|
|
.iter()
|
|
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
|
.collect();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let ds = dev
|
|
.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
|
|
.w()?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
|
|
let out = dev.alloc_zeros::<T>(dst_el).w()?;
|
|
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct FastReduce<'a>(&'a [usize], ReduceOp);
|
|
impl<'a> Map1Any for FastReduce<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
wrap: W,
|
|
) -> Result<S> {
|
|
let src_stride = layout.stride();
|
|
let src_dims = layout.shape().dims();
|
|
let src_el: usize = src_dims.iter().product();
|
|
// Source dims and strides with the sum dims at the end.
|
|
let mut dims = vec![];
|
|
let mut stride = vec![];
|
|
let mut dst_el: usize = 1;
|
|
for (dim_idx, &d) in src_dims.iter().enumerate() {
|
|
if !self.0.contains(&dim_idx) {
|
|
dst_el *= d;
|
|
dims.push(d);
|
|
stride.push(src_stride[dim_idx]);
|
|
}
|
|
}
|
|
for &dim_idx in self.0.iter() {
|
|
dims.push(src_dims[dim_idx]);
|
|
stride.push(src_stride[dim_idx]);
|
|
}
|
|
let el_to_sum_per_block = src_el / dst_el;
|
|
// The reduction loop requires the shared array to be properly initialized and for
|
|
// this we want the number of threads to be a power of two.
|
|
let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
|
|
let cfg = LaunchConfig {
|
|
// TODO: Maybe use grid_y if the output is too large?
|
|
// TODO: Specialized implementation when reducing on no or all dimensions or when
|
|
// reducing only aggregate a small number of elements together.
|
|
grid_dim: (dst_el as u32, 1, 1),
|
|
block_dim: (block_dim as u32, 1, 1),
|
|
shared_mem_bytes: 0,
|
|
};
|
|
let ds = dev
|
|
.htod_copy([dims.as_slice(), stride.as_slice()].concat())
|
|
.w()?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let (name, check_empty, return_index) = match self.1 {
|
|
ReduceOp::Sum => ("fast_sum", false, false),
|
|
ReduceOp::Min => ("fast_min", true, false),
|
|
ReduceOp::Max => ("fast_max", true, false),
|
|
ReduceOp::ArgMin => ("fast_argmin", true, true),
|
|
ReduceOp::ArgMax => ("fast_argmax", true, true),
|
|
};
|
|
if check_empty && layout.shape().elem_count() == 0 {
|
|
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
|
}
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
|
|
if return_index {
|
|
// SAFETY: filled in by the follow up kernel.
|
|
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
|
|
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(S::U32(out))
|
|
} else {
|
|
// SAFETY: filled in by the follow up kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(wrap(out))
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<U: UnaryOpT> Map1 for U {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
layout: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let el_count = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
|
let src = &src.slice(layout.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
|
let params = (el_count, dims.len(), &ds, src, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
|
|
impl<'a> Map1 for IndexSelect<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
src_l: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let ids_l = &self.1;
|
|
let (name, ids) = match &self.0.slice {
|
|
CudaStorageSlice::U32(slice) => {
|
|
("is_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
|
|
}
|
|
CudaStorageSlice::U8(slice) => {
|
|
("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
|
}
|
|
CudaStorageSlice::I64(slice) => {
|
|
("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr())
|
|
}
|
|
_ => Err(CudaError::UnexpectedDType {
|
|
msg: "index_select ids should be u8 or u32",
|
|
expected: DType::U32,
|
|
got: self.0.dtype(),
|
|
})
|
|
.w()?,
|
|
};
|
|
let ids_shape = ids_l.shape();
|
|
let ids_dims = ids_shape.dims();
|
|
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
|
|
let src = match src_l.contiguous_offsets() {
|
|
Some((o1, o2)) => src.slice(o1..o2),
|
|
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
|
|
};
|
|
let left_size: usize = src_l.dims()[..self.2].iter().product();
|
|
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
|
|
let src_dim_size = src_l.dims()[self.2];
|
|
let ids_dim_size = ids_shape.elem_count();
|
|
let dst_el = ids_shape.elem_count() * left_size * right_size;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let params = (
|
|
dst_el,
|
|
ids_dims.len(),
|
|
&ds,
|
|
ids,
|
|
&src,
|
|
&out,
|
|
left_size,
|
|
src_dim_size,
|
|
ids_dim_size,
|
|
right_size,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct Gather<'a>(&'a CudaStorage, &'a Layout, usize);
|
|
impl<'a> Map1 for Gather<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
src: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
src_l: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
let ids = &self.0;
|
|
let ids_l = &self.1;
|
|
let dim = self.2;
|
|
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
|
Some(o12) => o12,
|
|
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
|
};
|
|
let (name, ids) = match &ids.slice {
|
|
CudaStorageSlice::U32(slice) => {
|
|
("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr())
|
|
}
|
|
CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
CudaStorageSlice::I64(slice) => {
|
|
("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr())
|
|
}
|
|
_ => Err(CudaError::UnexpectedDType {
|
|
msg: "gather ids should be u8/u32/i64",
|
|
expected: DType::U32,
|
|
got: ids.dtype(),
|
|
})?,
|
|
};
|
|
let el = ids_l.shape().elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let src = match src_l.contiguous_offsets() {
|
|
Some((o1, o2)) => src.slice(o1..o2),
|
|
None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
|
|
};
|
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
|
let src_dim_sz = src_l.dims()[dim];
|
|
let ids_dim_sz = ids_l.dims()[dim];
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
|
let params = (
|
|
el, ids, &src, &out, left_sz, src_dim_sz, ids_dim_sz, right_sz,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct IndexAdd<'a>(&'a CudaStorage, &'a Layout, usize);
|
|
impl<'a> Map2InPlace for IndexAdd<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
dst: &mut CudaSlice<T>,
|
|
dst_shape: &Shape,
|
|
src: &CudaSlice<T>,
|
|
src_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<()> {
|
|
let ids = &self.0;
|
|
let ids_l = &self.1;
|
|
let dim = self.2;
|
|
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
|
Some(o12) => o12,
|
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
|
};
|
|
let (name, ids) = match &ids.slice {
|
|
CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
_ => Err(CudaError::UnexpectedDType {
|
|
msg: "index-add ids should be u8/u32/i64",
|
|
expected: DType::U32,
|
|
got: ids.dtype(),
|
|
})?,
|
|
};
|
|
let src = match src_l.contiguous_offsets() {
|
|
Some((o1, o2)) => src.slice(o1..o2),
|
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
|
};
|
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
|
let src_dim_sz = src_l.dims()[dim];
|
|
let dst_dim_sz = dst_shape.dims()[dim];
|
|
let ids_dim_sz = ids_l.dims()[0];
|
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (
|
|
ids, ids_dim_sz, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
struct ScatterAdd<'a>(&'a CudaStorage, &'a Layout, usize);
|
|
impl<'a> Map2InPlace for ScatterAdd<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
dst: &mut CudaSlice<T>,
|
|
dst_shape: &Shape,
|
|
src: &CudaSlice<T>,
|
|
src_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<()> {
|
|
let ids = &self.0;
|
|
let ids_l = &self.1;
|
|
let dim = self.2;
|
|
let (ids_o1, ids_o2) = match ids_l.contiguous_offsets() {
|
|
Some(o12) => o12,
|
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
|
};
|
|
let (name, ids) = match &ids.slice {
|
|
CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()),
|
|
_ => Err(CudaError::UnexpectedDType {
|
|
msg: "scatter-add ids should be u8/u32/i64",
|
|
expected: DType::U32,
|
|
got: ids.dtype(),
|
|
})?,
|
|
};
|
|
let src = match src_l.contiguous_offsets() {
|
|
Some((o1, o2)) => src.slice(o1..o2),
|
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
|
};
|
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
|
let src_dim_sz = src_l.dims()[dim];
|
|
let dst_dim_sz = dst_shape.dims()[dim];
|
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (ids, &src, dst, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
|
impl<'a> Map2 for Conv1D<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
inp: &CudaSlice<T>,
|
|
inp_l: &Layout,
|
|
k: &CudaSlice<T>,
|
|
k_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<CudaSlice<T>> {
|
|
// Kernel shape: (c_out, c_in_k, k_size)
|
|
// Input shape: (b_size, c_in, l_in) or (c_in, l_in)
|
|
let p = &self.0;
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(k_l.start_offset()..);
|
|
let shape = inp_l.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let l_out = p.l_out();
|
|
let dst_el = p.c_out * l_out * p.b_size;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), kernels::CONV)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let ds = if dims.len() == 3 {
|
|
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
|
} else if dims.len() == 2 {
|
|
[&[1], dims, &[1], inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
|
} else {
|
|
crate::bail!("unexpected input shape for conv1d {dims:?}")
|
|
};
|
|
let ds = dev.htod_copy(ds).w()?;
|
|
let params = (
|
|
el, l_out, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
|
impl<'a> Map2 for Conv2D<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
inp: &CudaSlice<T>,
|
|
inp_l: &Layout,
|
|
k: &CudaSlice<T>,
|
|
k_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<CudaSlice<T>> {
|
|
// Kernel shape: (c_out, c_in_k, h_k, w_k)
|
|
// Input shape: (b_size, c_in, h_in, w_in)
|
|
let p = &self.0;
|
|
let (out_w, out_h) = (p.out_w(), p.out_h());
|
|
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(k_l.start_offset()..);
|
|
let shape = inp_l.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), kernels::CONV)?;
|
|
let ds = if dims.len() == 4 {
|
|
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
|
} else {
|
|
crate::bail!("unexpected input shape for conv2d {dims:?}")
|
|
};
|
|
let ds = dev.htod_copy(ds).w()?;
|
|
let params = (
|
|
el, out_w, out_h, p.stride, p.padding, p.dilation, &ds, inp, k, &out,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
|
impl<'a> Map2 for ConvTranspose1D<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
inp: &CudaSlice<T>,
|
|
inp_l: &Layout,
|
|
k: &CudaSlice<T>,
|
|
k_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<CudaSlice<T>> {
|
|
// Kernel shape: (c_in_k, c_out, l_k)
|
|
// Input shape: (b_size, c_in, l_in)
|
|
let p = &self.0;
|
|
let l_out = p.l_out();
|
|
let dst_el = p.c_out * l_out * p.b_size;
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(k_l.start_offset()..);
|
|
let shape = inp_l.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
|
|
let ds = if dims.len() == 3 {
|
|
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
|
} else {
|
|
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
|
};
|
|
let ds = dev.htod_copy(ds).w()?;
|
|
let params = (
|
|
el,
|
|
l_out,
|
|
p.stride,
|
|
p.padding,
|
|
p.output_padding,
|
|
p.dilation,
|
|
&ds,
|
|
inp,
|
|
k,
|
|
&out,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
|
impl<'a> Map2 for ConvTranspose2D<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
inp: &CudaSlice<T>,
|
|
inp_l: &Layout,
|
|
k: &CudaSlice<T>,
|
|
k_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<CudaSlice<T>> {
|
|
// Kernel shape: (c_in_k, c_out, h_k, w_k)
|
|
// Input shape: (b_size, c_in, h_in, w_in)
|
|
let p = &self.0;
|
|
let (out_w, out_h) = (p.out_w(), p.out_h());
|
|
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(k_l.start_offset()..);
|
|
let shape = inp_l.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), kernels::CONV)?;
|
|
let ds = if dims.len() == 4 {
|
|
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
|
} else {
|
|
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
|
|
};
|
|
let ds = dev.htod_copy(ds).w()?;
|
|
let params = (
|
|
el,
|
|
out_w,
|
|
out_h,
|
|
p.stride,
|
|
p.padding,
|
|
p.output_padding,
|
|
p.dilation,
|
|
&ds,
|
|
inp,
|
|
k,
|
|
&out,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
/// Reduction applied over each pooling window; selects the `max_pool2d` or
/// `avg_pool2d` kernel.
enum PoolOp {
    /// Uses the `max_pool2d` kernel.
    Max,
    /// Uses the `avg_pool2d` kernel.
    Avg,
}
|
|
|
|
struct Pool2D {
|
|
w_k: usize,
|
|
h_k: usize,
|
|
w_stride: usize,
|
|
h_stride: usize,
|
|
op: PoolOp,
|
|
}
|
|
|
|
impl Map1 for Pool2D {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
inp: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
inp_l: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
// Input shape: (b_size, c, h, w)
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let shape = inp_l.shape();
|
|
let dims = shape.dims();
|
|
let ds = if dims.len() == 4 {
|
|
[dims, inp_l.stride()].concat()
|
|
} else {
|
|
crate::bail!("unexpected input shape for pool {dims:?}")
|
|
};
|
|
let el = shape.elem_count();
|
|
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
|
|
let out_h = (dims[3] - self.h_k) / self.h_stride + 1;
|
|
let dst_el = out_w * out_h * dims[0] * dims[1];
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let kname = match self.op {
|
|
PoolOp::Max => "max_pool2d",
|
|
PoolOp::Avg => "avg_pool2d",
|
|
};
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(kname), kernels::CONV)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let ds = dev.htod_copy(ds).w()?;
|
|
let params = (
|
|
el,
|
|
self.w_k,
|
|
self.h_k,
|
|
self.w_stride,
|
|
self.h_stride,
|
|
&ds,
|
|
inp,
|
|
&out,
|
|
);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct UpsampleNearest2D(usize, usize);
|
|
impl Map1 for UpsampleNearest2D {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
inp: &CudaSlice<T>,
|
|
dev: &CudaDevice,
|
|
inp_l: &Layout,
|
|
) -> Result<CudaSlice<T>> {
|
|
// Input shape: (b_size, c, h, w)
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let shape = inp_l.shape();
|
|
let dims = shape.dims();
|
|
let ds = if dims.len() == 4 {
|
|
[dims, inp_l.stride()].concat()
|
|
} else {
|
|
crate::bail!("unexpected input shape for upsample {dims:?}")
|
|
};
|
|
let (out_w, out_h) = (self.0, self.1);
|
|
let dst_el = out_w * out_h * dims[0] * dims[1];
|
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), kernels::CONV)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
|
let ds = dev.htod_copy(ds).w()?;
|
|
let scale_w = dims[2] as f64 / out_w as f64;
|
|
let scale_h = dims[3] as f64 / out_h as f64;
|
|
let params = (out_w, out_h, scale_w, scale_h, &ds, inp, &out);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct WhereCond<'a>(&'a CudaStorage, &'a Layout);
|
|
impl<'a> Map2 for WhereCond<'a> {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
t: &CudaSlice<T>,
|
|
layout_t: &Layout,
|
|
f: &CudaSlice<T>,
|
|
layout_f: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<CudaSlice<T>> {
|
|
let ids_l = &self.1;
|
|
let (ids, name) = match &self.0.slice {
|
|
CudaStorageSlice::U8(slice) => {
|
|
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
|
(ptr, "where_u8")
|
|
}
|
|
CudaStorageSlice::U32(slice) => {
|
|
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
|
(ptr, "where_u32")
|
|
}
|
|
CudaStorageSlice::I64(slice) => {
|
|
let ptr = *slice.slice(ids_l.start_offset()..).device_ptr();
|
|
(ptr, "where_i64")
|
|
}
|
|
_ => Err(CudaError::UnexpectedDType {
|
|
msg: "where conditions should be u8/u32/i64",
|
|
expected: DType::U32,
|
|
got: self.0.dtype(),
|
|
})
|
|
.w()?,
|
|
};
|
|
let shape = ids_l.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let ds = dev
|
|
.htod_copy([dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
|
|
.w()?;
|
|
let t = &t.slice(layout_t.start_offset()..);
|
|
let f = &f.slice(layout_f.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::TERNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
|
// SAFETY: ffi
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
impl<U: crate::op::BinaryOpT> Map2 for U {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
lhs: &CudaSlice<T>,
|
|
lhs_l: &Layout,
|
|
rhs: &CudaSlice<T>,
|
|
rhs_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<CudaSlice<T>> {
|
|
let shape = lhs_l.shape();
|
|
let dims = shape.dims();
|
|
let elem_count = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
|
SlicePtrOrNull::Null
|
|
} else {
|
|
SlicePtrOrNull::Ptr(
|
|
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
|
.w()?,
|
|
)
|
|
};
|
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), kernels::BINARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
// SAFETY: ffi
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(out)
|
|
}
|
|
}
|
|
|
|
struct Cmp(CmpOp);
|
|
impl Map2Any for Cmp {
|
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
|
&self,
|
|
lhs: &CudaSlice<T>,
|
|
lhs_l: &Layout,
|
|
rhs: &CudaSlice<T>,
|
|
rhs_l: &Layout,
|
|
dev: &CudaDevice,
|
|
) -> Result<S> {
|
|
let shape = lhs_l.shape();
|
|
let dims = shape.dims();
|
|
let elem_count = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
|
SlicePtrOrNull::Null
|
|
} else {
|
|
SlicePtrOrNull::Ptr(
|
|
dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())
|
|
.w()?,
|
|
)
|
|
};
|
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
let name = match self.0 {
|
|
CmpOp::Eq => "eq",
|
|
CmpOp::Ne => "ne",
|
|
CmpOp::Lt => "lt",
|
|
CmpOp::Le => "le",
|
|
CmpOp::Gt => "gt",
|
|
CmpOp::Ge => "ge",
|
|
};
|
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::BINARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
|
|
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
|
// SAFETY: ffi
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(S::U8(out))
|
|
}
|
|
}
|
|
|
|
fn slice_src_and_dst<'a, T>(
|
|
src: &'a CudaSlice<T>,
|
|
src_l: &Layout,
|
|
dst: &'a mut CudaSlice<T>,
|
|
dst_offset: usize,
|
|
) -> (
|
|
cudarc::driver::CudaView<'a, T>,
|
|
cudarc::driver::CudaViewMut<'a, T>,
|
|
) {
|
|
let src_offset = src_l.start_offset();
|
|
let to_copy = dst
|
|
.len()
|
|
.saturating_sub(dst_offset)
|
|
.min(src.len().saturating_sub(src_offset));
|
|
let src = src.slice(src_offset..src_offset + to_copy);
|
|
let dst = dst.slice_mut(dst_offset..dst_offset + to_copy);
|
|
(src, dst)
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct CudaStorage {
|
|
pub slice: CudaStorageSlice,
|
|
pub device: CudaDevice,
|
|
}
|
|
|
|
pub trait CudaDType: Sized {
|
|
fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>>;
|
|
fn wrap_cuda_slice(s: CudaSlice<Self>, dev: CudaDevice) -> CudaStorage;
|
|
}
|
|
|
|
// Implements `CudaDType` for `$ty`, binding it to the `CudaStorageSlice::$dtype` variant.
macro_rules! cuda_dtype {
    ($ty:ty, $dtype:ident) => {
        impl CudaDType for $ty {
            fn as_cuda_slice(s: &CudaStorage) -> Result<&CudaSlice<Self>> {
                if let CudaStorageSlice::$dtype(data) = &s.slice {
                    return Ok(data);
                }
                Err(crate::Error::UnexpectedDType {
                    expected: DType::$dtype,
                    got: s.dtype(),
                    msg: "unexpected dtype",
                }
                .bt())
            }

            fn wrap_cuda_slice(slice: CudaSlice<Self>, device: CudaDevice) -> CudaStorage {
                CudaStorage {
                    slice: CudaStorageSlice::$dtype(slice),
                    device,
                }
            }
        }
    };
}
|
|
cuda_dtype!(u8, U8);
|
|
cuda_dtype!(u32, U32);
|
|
cuda_dtype!(i64, I64);
|
|
cuda_dtype!(f16, F16);
|
|
cuda_dtype!(bf16, BF16);
|
|
cuda_dtype!(f32, F32);
|
|
cuda_dtype!(f64, F64);
|
|
|
|
impl CudaStorage {
|
|
pub fn wrap_cuda_slice<T: CudaDType>(slice: CudaSlice<T>, device: CudaDevice) -> CudaStorage {
|
|
T::wrap_cuda_slice(slice, device)
|
|
}
|
|
|
|
pub fn as_cuda_slice<T: CudaDType>(&self) -> Result<&CudaSlice<T>> {
|
|
T::as_cuda_slice(self)
|
|
}
|
|
}
|
|
|
|
fn gemm_config<T>(
|
|
alpha: T,
|
|
beta: T,
|
|
(b, m, n, k): (usize, usize, usize, usize),
|
|
lhs_l: &Layout,
|
|
rhs_l: &Layout,
|
|
) -> Result<StridedBatchedConfig<T>> {
|
|
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
|
|
use cudarc::cublas::sys::cublasOperation_t;
|
|
|
|
let lhs_stride = lhs_l.stride();
|
|
let rhs_stride = rhs_l.stride();
|
|
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
|
// The a tensor has dims batching, k, n (rhs)
|
|
// We also allow for the case where the stride on the minor dimension is not as expected but
|
|
// there is a single element.
|
|
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
|
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
|
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
|
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
|
} else {
|
|
Err(CudaError::MatMulNonContiguous {
|
|
lhs_stride: lhs_l.clone(),
|
|
rhs_stride: rhs_l.clone(),
|
|
mnk: (m, n, k),
|
|
})?
|
|
};
|
|
// The b tensor has dims batching, m, k (lhs)
|
|
// We also allow for the case where the stride on the minor dimension is not as expected but
|
|
// there is a single element.
|
|
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
|
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
|
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
|
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
|
} else {
|
|
Err(CudaError::MatMulNonContiguous {
|
|
lhs_stride: lhs_l.clone(),
|
|
rhs_stride: rhs_l.clone(),
|
|
mnk: (m, n, k),
|
|
})?
|
|
};
|
|
// The setup below was copied from:
|
|
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
|
|
let gemm = GemmConfig {
|
|
alpha,
|
|
beta,
|
|
m: n as i32,
|
|
n: m as i32,
|
|
k: k as i32,
|
|
lda,
|
|
ldb,
|
|
ldc: n as i32,
|
|
transa,
|
|
transb,
|
|
};
|
|
|
|
let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
|
|
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
|
[_, stride] if lhs_l.dims()[0] == 1 => stride,
|
|
[stride, _] if lhs_l.dims()[1] == 1 => stride,
|
|
[stride] => stride,
|
|
[] => m * k,
|
|
_ => Err(CudaError::MatMulNonContiguous {
|
|
lhs_stride: lhs_l.clone(),
|
|
rhs_stride: rhs_l.clone(),
|
|
mnk: (m, n, k),
|
|
})?,
|
|
};
|
|
let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
|
|
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
|
[_, stride] if rhs_l.dims()[0] == 1 => stride,
|
|
[stride, _] if rhs_l.dims()[1] == 1 => stride,
|
|
[stride] => stride,
|
|
[] => n * k,
|
|
_ => Err(CudaError::MatMulNonContiguous {
|
|
lhs_stride: lhs_l.clone(),
|
|
rhs_stride: rhs_l.clone(),
|
|
mnk: (m, n, k),
|
|
})?,
|
|
};
|
|
|
|
Ok(StridedBatchedConfig {
|
|
batch_size: b as i32,
|
|
gemm,
|
|
stride_a: stride_a as i64,
|
|
stride_b: stride_b as i64,
|
|
stride_c: (m * n) as i64,
|
|
})
|
|
}
|
|
|
|
impl BackendStorage for CudaStorage {
|
|
type Device = CudaDevice;
|
|
|
|
fn try_clone(&self, layout: &Layout) -> Result<Self> {
|
|
let slice = Clone.map(&self.slice, self.device(), layout)?;
|
|
let device = self.device.clone();
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn dtype(&self) -> DType {
|
|
match self.slice {
|
|
CudaStorageSlice::U8(_) => DType::U8,
|
|
CudaStorageSlice::U32(_) => DType::U32,
|
|
CudaStorageSlice::I64(_) => DType::I64,
|
|
CudaStorageSlice::BF16(_) => DType::BF16,
|
|
CudaStorageSlice::F16(_) => DType::F16,
|
|
CudaStorageSlice::F32(_) => DType::F32,
|
|
CudaStorageSlice::F64(_) => DType::F64,
|
|
}
|
|
}
|
|
|
|
fn device(&self) -> &CudaDevice {
|
|
&self.device
|
|
}
|
|
|
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
|
let shape = layout.shape();
|
|
let dims = shape.dims();
|
|
let el = shape.elem_count();
|
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
let dev = self.device();
|
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
|
let start_o = layout.start_offset();
|
|
// This returns an i64 rather than a &i64, this is useful to get around some temporary
|
|
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
|
|
// is used.
|
|
let inp = match &self.slice {
|
|
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
|
|
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
|
|
CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(),
|
|
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
|
|
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
|
|
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
|
|
CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
|
|
};
|
|
let inp = &inp;
|
|
|
|
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
|
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
|
let slice = match dtype {
|
|
DType::U8 => {
|
|
let out = unsafe { dev.alloc::<u8>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::U8(out)
|
|
}
|
|
DType::U32 => {
|
|
let out = unsafe { dev.alloc::<u32>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::U32(out)
|
|
}
|
|
DType::I64 => {
|
|
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::I64(out)
|
|
}
|
|
DType::BF16 => {
|
|
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::BF16(out)
|
|
}
|
|
DType::F16 => {
|
|
let out = unsafe { dev.alloc::<f16>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::F16(out)
|
|
}
|
|
DType::F32 => {
|
|
let out = unsafe { dev.alloc::<f32>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::F32(out)
|
|
}
|
|
DType::F64 => {
|
|
let out = unsafe { dev.alloc::<f64>(el) }.w()?;
|
|
let params = (el, dims.len(), &ds, *inp, &out);
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
CudaStorageSlice::F64(out)
|
|
}
|
|
};
|
|
Ok(Self {
|
|
slice,
|
|
device: dev.clone(),
|
|
})
|
|
}
|
|
|
|
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Powf(e).map(&self.slice, &device, layout)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Elu(alpha).map(&self.slice, &device, layout)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Cmp(op).map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = U::V.map(&self.slice, &device, layout)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn binary_impl<B: BinaryOpT>(
|
|
&self,
|
|
rhs: &Self,
|
|
lhs_l: &Layout,
|
|
rhs_l: &Layout,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = B::V.map(&self.slice, lhs_l, &rhs.slice, rhs_l, &device)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
|
match &self.slice {
|
|
CudaStorageSlice::U8(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::U8(cpu_storage))
|
|
}
|
|
CudaStorageSlice::U32(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::U32(cpu_storage))
|
|
}
|
|
CudaStorageSlice::I64(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::I64(cpu_storage))
|
|
}
|
|
CudaStorageSlice::BF16(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::BF16(cpu_storage))
|
|
}
|
|
CudaStorageSlice::F16(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::F16(cpu_storage))
|
|
}
|
|
CudaStorageSlice::F32(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::F32(cpu_storage))
|
|
}
|
|
CudaStorageSlice::F64(slice) => {
|
|
let dev = slice.device();
|
|
let cpu_storage = dev.dtoh_sync_copy(slice).w()?;
|
|
Ok(CpuStorage::F64(cpu_storage))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn where_cond(
|
|
&self,
|
|
layout: &Layout,
|
|
t: &Self,
|
|
t_l: &Layout,
|
|
f: &Self,
|
|
f_l: &Layout,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = WhereCond(self, layout).map(&t.slice, t_l, &f.slice, f_l, &device)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn conv1d(
|
|
&self,
|
|
l: &Layout,
|
|
kernel: &Self,
|
|
kernel_l: &Layout,
|
|
params: &crate::conv::ParamsConv1D,
|
|
) -> Result<Self> {
|
|
const USE_IM2COL_CONV1D: bool = true;
|
|
|
|
let device = self.device().clone();
|
|
if !USE_IM2COL_CONV1D {
|
|
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
|
return Ok(Self { slice, device });
|
|
}
|
|
|
|
let col = Im2Col1D {
|
|
l_k: params.k_size,
|
|
stride: params.stride,
|
|
dilation: params.dilation,
|
|
padding: params.padding,
|
|
}
|
|
.map(&self.slice, &device, l)?;
|
|
let col = Self { slice: col, device };
|
|
let l_out = params.l_out();
|
|
let b = params.b_size;
|
|
let n = params.c_out;
|
|
let k = params.k_size * params.c_in;
|
|
let m = l_out;
|
|
let col_l = Layout::contiguous((b, m, k));
|
|
let res = if kernel_l.is_contiguous() {
|
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
|
.transpose(1, 2)?
|
|
.broadcast_as((b, k, n))?;
|
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
|
} else {
|
|
// Make the kernel contiguous if not already the case.
|
|
let mut kernel_c = unsafe {
|
|
self.device()
|
|
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
|
};
|
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
|
.transpose(1, 2)?
|
|
.broadcast_as((b, k, n))?;
|
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
|
};
|
|
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
|
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
|
Ok(res_t)
|
|
}
|
|
|
|
fn conv_transpose1d(
|
|
&self,
|
|
l: &Layout,
|
|
kernel: &Self,
|
|
kernel_l: &Layout,
|
|
params: &crate::conv::ParamsConvTranspose1D,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice =
|
|
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
#[cfg(not(feature = "cudnn"))]
|
|
fn conv2d(
|
|
&self,
|
|
l: &Layout,
|
|
kernel: &Self,
|
|
kernel_l: &Layout,
|
|
params: &crate::conv::ParamsConv2D,
|
|
) -> Result<Self> {
|
|
const USE_IM2COL_CONV2D: bool = true;
|
|
|
|
let device = self.device().clone();
|
|
if !USE_IM2COL_CONV2D {
|
|
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
|
return Ok(Self { slice, device });
|
|
}
|
|
|
|
let col = Im2Col {
|
|
h_k: params.k_h,
|
|
w_k: params.k_w,
|
|
stride: params.stride,
|
|
dilation: params.dilation,
|
|
padding: params.padding,
|
|
}
|
|
.map(&self.slice, &device, l)?;
|
|
let col = Self { slice: col, device };
|
|
let h_out = params.out_h();
|
|
let w_out = params.out_w();
|
|
let b = params.b_size;
|
|
let n = params.c_out;
|
|
let k = params.k_h * params.k_w * params.c_in;
|
|
let m = h_out * w_out;
|
|
let col_l = Layout::contiguous((b, m, k));
|
|
let res = if kernel_l.is_contiguous() {
|
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
|
.transpose(1, 2)?
|
|
.broadcast_as((b, k, n))?;
|
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
|
} else {
|
|
// Make the kernel contiguous if not already the case.
|
|
let mut kernel_c = unsafe {
|
|
self.device()
|
|
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
|
};
|
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
|
.transpose(1, 2)?
|
|
.broadcast_as((b, k, n))?;
|
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
|
};
|
|
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
|
.transpose(1, 2)?
|
|
.transpose(1, 3)?;
|
|
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
|
res.copy_strided_src(&mut res_t, 0, &res_l)?;
|
|
Ok(res_t)
|
|
}
|
|
|
|
#[cfg(feature = "cudnn")]
|
|
fn conv2d(
|
|
&self,
|
|
inp_l: &Layout,
|
|
kernel: &Self,
|
|
kernel_l: &Layout,
|
|
params: &crate::conv::ParamsConv2D,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
if !kernel_l.is_contiguous() {
|
|
let slice = Conv2D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
|
|
return Ok(Self { slice, device });
|
|
}
|
|
let (out_w, out_h) = (params.out_w(), params.out_h());
|
|
let dst_el = params.c_out * out_w * out_h * params.b_size;
|
|
let slice = match (&self.slice, &kernel.slice) {
|
|
(S::U8(inp), S::U8(k)) => {
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(kernel_l.start_offset()..);
|
|
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
|
|
crate::cudnn::launch_conv2d::<u8>(inp, inp_l, k, &mut out, params, &device)
|
|
.map_err(crate::Error::wrap)?;
|
|
S::U8(out)
|
|
}
|
|
(S::BF16(inp), S::BF16(k)) => {
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(kernel_l.start_offset()..);
|
|
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
|
|
crate::cudnn::launch_conv2d::<bf16>(inp, inp_l, k, &mut out, params, &device)
|
|
.map_err(crate::Error::wrap)?;
|
|
S::BF16(out)
|
|
}
|
|
(S::F16(inp), S::F16(k)) => {
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(kernel_l.start_offset()..);
|
|
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
|
|
crate::cudnn::launch_conv2d::<f16>(inp, inp_l, k, &mut out, params, &device)
|
|
.map_err(crate::Error::wrap)?;
|
|
S::F16(out)
|
|
}
|
|
(S::F32(inp), S::F32(k)) => {
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(kernel_l.start_offset()..);
|
|
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
|
|
crate::cudnn::launch_conv2d::<f32>(inp, inp_l, k, &mut out, params, &device)
|
|
.map_err(crate::Error::wrap)?;
|
|
S::F32(out)
|
|
}
|
|
(S::F64(inp), S::F64(k)) => {
|
|
let inp = &inp.slice(inp_l.start_offset()..);
|
|
let k = &k.slice(kernel_l.start_offset()..);
|
|
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
|
|
crate::cudnn::launch_conv2d::<f64>(inp, inp_l, k, &mut out, params, &device)
|
|
.map_err(crate::Error::wrap)?;
|
|
S::F64(out)
|
|
}
|
|
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?,
|
|
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?,
|
|
_ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?,
|
|
};
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn conv_transpose2d(
|
|
&self,
|
|
l: &Layout,
|
|
kernel: &Self,
|
|
kernel_l: &Layout,
|
|
params: &crate::conv::ParamsConvTranspose2D,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice =
|
|
ConvTranspose2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn avg_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Pool2D {
|
|
w_k: k.0,
|
|
h_k: k.1,
|
|
w_stride: stride.0,
|
|
h_stride: stride.1,
|
|
op: PoolOp::Avg,
|
|
}
|
|
.map(&self.slice, &device, l)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Pool2D {
|
|
w_k: k.0,
|
|
h_k: k.1,
|
|
w_stride: stride.0,
|
|
h_stride: stride.1,
|
|
op: PoolOp::Max,
|
|
}
|
|
.map(&self.slice, &device, l)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
|
|
crate::bail!("upsample-nearest1d is not supported on cuda")
|
|
}
|
|
|
|
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
|
Ok(Self { slice, device })
|
|
}
|
|
fn scatter_add(
|
|
&self,
|
|
l: &Layout,
|
|
ids: &Self,
|
|
ids_l: &Layout,
|
|
src: &Self,
|
|
src_l: &Layout,
|
|
dim: usize,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
|
self.copy_strided_src(&mut acc, 0, l)?;
|
|
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
|
Ok(acc)
|
|
}
|
|
fn index_add(
|
|
&self,
|
|
l: &Layout,
|
|
ids: &Self,
|
|
ids_l: &Layout,
|
|
src: &Self,
|
|
src_l: &Layout,
|
|
dim: usize,
|
|
) -> Result<Self> {
|
|
let device = self.device().clone();
|
|
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
|
self.copy_strided_src(&mut acc, 0, l)?;
|
|
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
|
Ok(acc)
|
|
}
|
|
|
|
fn matmul(
|
|
&self,
|
|
rhs: &Self,
|
|
(b, m, n, k): (usize, usize, usize, usize),
|
|
lhs_l: &Layout,
|
|
rhs_l: &Layout,
|
|
) -> Result<Self> {
|
|
let elem_count = b * m * n;
|
|
let dev = &self.device;
|
|
let slice = match (&self.slice, &rhs.slice) {
|
|
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
|
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
|
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
|
.w()?;
|
|
CudaStorageSlice::BF16(out)
|
|
}
|
|
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
|
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
|
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
|
.w()?;
|
|
CudaStorageSlice::F16(out)
|
|
}
|
|
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
|
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
|
unsafe {
|
|
self.device
|
|
.blas
|
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
|
}
|
|
.w()?;
|
|
CudaStorageSlice::F32(out)
|
|
}
|
|
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
|
let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
|
|
unsafe {
|
|
self.device
|
|
.blas
|
|
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
|
}
|
|
.w()?;
|
|
CudaStorageSlice::F64(out)
|
|
}
|
|
_ => Err(CudaError::InternalError("dtype mismatch in matmul op"))?,
|
|
};
|
|
let device = dev.clone();
|
|
Ok(Self { slice, device })
|
|
}
|
|
|
|
fn copy2d(
|
|
&self,
|
|
dst: &mut Self,
|
|
d1: usize,
|
|
d2: usize,
|
|
src_s: usize,
|
|
dst_s: usize,
|
|
src_o: usize,
|
|
dst_o: usize,
|
|
) -> Result<()> {
|
|
let dev = &self.device;
|
|
let d1 = d1 as u32;
|
|
let d2 = d2 as u32;
|
|
// Nothing to copy so we exit early to avoid launching a kernel and some potential invalid
|
|
// argument with a null pointer.
|
|
if d1 == 0 || d2 == 0 {
|
|
return Ok(());
|
|
}
|
|
let dst_s = dst_s as u32;
|
|
let src_s = src_s as u32;
|
|
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
|
|
(S::U8(s), S::U8(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_u8",
|
|
),
|
|
(S::U32(s), S::U32(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_u32",
|
|
),
|
|
(S::I64(s), S::I64(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_i64",
|
|
),
|
|
(S::BF16(s), S::BF16(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_bf16",
|
|
),
|
|
(S::F16(s), S::F16(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_f16",
|
|
),
|
|
(S::F32(s), S::F32(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_f32",
|
|
),
|
|
(S::F64(s), S::F64(d)) => (
|
|
*s.slice(src_o..).device_ptr(),
|
|
*d.slice(dst_o..).device_ptr(),
|
|
"copy2d_f64",
|
|
),
|
|
_ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
|
|
};
|
|
let func = dev.get_or_load_func(kname, kernels::FILL)?;
|
|
let cfg = LaunchConfig::for_num_elems(d1 * d2);
|
|
let params = (src, dst, d1, d2, src_s, dst_s);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
Ok(())
|
|
}
|
|
|
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
|
let src_shape = src_l.shape();
|
|
let dims = src_shape.dims();
|
|
let el_count = src_shape.elem_count();
|
|
if el_count == 0 {
|
|
return Ok(());
|
|
}
|
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
|
let dev = &self.device;
|
|
let ds = SlicePtrOrNull::params_from_layout(dev, src_l)?;
|
|
match (&self.slice, &mut dst.slice) {
|
|
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?
|
|
}
|
|
}
|
|
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?
|
|
}
|
|
}
|
|
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?
|
|
}
|
|
}
|
|
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?
|
|
}
|
|
}
|
|
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?
|
|
}
|
|
}
|
|
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_i64", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?
|
|
}
|
|
}
|
|
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
|
if src_l.is_contiguous() {
|
|
dev.dtod_copy(&src, &mut dst).w()?
|
|
} else {
|
|
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
|
|
// SAFETY: Set later by running the kernel.
|
|
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
|
// SAFETY: ffi.
|
|
unsafe { func.launch(cfg, params) }.w()?;
|
|
}
|
|
}
|
|
_ => Err(CudaError::InternalError(
|
|
"dtype mismatch in copy_strided op",
|
|
))?,
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// Reduced-precision GEMM reductions are disabled by default, matching PyTorch:
// https://github.com/pytorch/pytorch/issues/123157
static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Whether f16 GEMMs may use reduced-precision reductions (e.g. f16 accumulation).
/// Defaults to `false`.
pub fn gemm_reduced_precision_f16() -> bool {
    let flag = &MM_F16_REDUCED_PRECISION;
    flag.load(std::sync::atomic::Ordering::Relaxed)
}

/// Enables (`true`) or disables (`false`) reduced-precision reductions for f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(b: bool) {
    let flag = &MM_F16_REDUCED_PRECISION;
    flag.store(b, std::sync::atomic::Ordering::Relaxed)
}

/// Whether bf16 GEMMs may use reduced-precision reductions. Defaults to `false`.
pub fn gemm_reduced_precision_bf16() -> bool {
    let flag = &MM_BF16_REDUCED_PRECISION;
    flag.load(std::sync::atomic::Ordering::Relaxed)
}

/// Enables (`true`) or disables (`false`) reduced-precision reductions for bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(b: bool) {
    let flag = &MM_BF16_REDUCED_PRECISION;
    flag.store(b, std::sync::atomic::Ordering::Relaxed)
}
|
|
|
|
unsafe fn gemm_strided_batched_f16(
|
|
cublas: &cudarc::cublas::CudaBlas,
|
|
cfg: StridedBatchedConfig<f16>,
|
|
a: &cudarc::driver::CudaView<f16>,
|
|
b: &cudarc::driver::CudaView<f16>,
|
|
c: &mut CudaSlice<f16>,
|
|
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
|
|
use cudarc::cublas::sys;
|
|
use cudarc::driver::DevicePtrMut;
|
|
|
|
let compute_type = if gemm_reduced_precision_f16() {
|
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
|
|
} else {
|
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
|
};
|
|
|
|
let alpha = cfg.gemm.alpha;
|
|
let beta = cfg.gemm.beta;
|
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
|
*cublas.handle(),
|
|
cfg.gemm.transa,
|
|
cfg.gemm.transb,
|
|
cfg.gemm.m,
|
|
cfg.gemm.n,
|
|
cfg.gemm.k,
|
|
(&alpha) as *const f16 as *const _,
|
|
*a.device_ptr() as *const _,
|
|
sys::cudaDataType_t::CUDA_R_16F,
|
|
cfg.gemm.lda,
|
|
cfg.stride_a,
|
|
*b.device_ptr() as *const _,
|
|
sys::cudaDataType_t::CUDA_R_16F,
|
|
cfg.gemm.ldb,
|
|
cfg.stride_b,
|
|
(&beta) as *const f16 as *const _,
|
|
*c.device_ptr_mut() as *mut _,
|
|
sys::cudaDataType_t::CUDA_R_16F,
|
|
cfg.gemm.ldc,
|
|
cfg.stride_c,
|
|
cfg.batch_size,
|
|
compute_type,
|
|
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
|
)
|
|
}
|
|
|
|
unsafe fn gemm_strided_batched_bf16(
|
|
cublas: &cudarc::cublas::CudaBlas,
|
|
cfg: StridedBatchedConfig<bf16>,
|
|
a: &cudarc::driver::CudaView<bf16>,
|
|
b: &cudarc::driver::CudaView<bf16>,
|
|
c: &mut CudaSlice<bf16>,
|
|
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
|
|
use cudarc::cublas::sys;
|
|
use cudarc::driver::DevicePtrMut;
|
|
|
|
let compute_type = if gemm_reduced_precision_bf16() {
|
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
|
|
} else {
|
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
|
};
|
|
|
|
let alpha = cfg.gemm.alpha;
|
|
let beta = cfg.gemm.beta;
|
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
|
*cublas.handle(),
|
|
cfg.gemm.transa,
|
|
cfg.gemm.transb,
|
|
cfg.gemm.m,
|
|
cfg.gemm.n,
|
|
cfg.gemm.k,
|
|
(&alpha) as *const bf16 as *const _,
|
|
*a.device_ptr() as *const _,
|
|
sys::cudaDataType_t::CUDA_R_16BF,
|
|
cfg.gemm.lda,
|
|
cfg.stride_a,
|
|
*b.device_ptr() as *const _,
|
|
sys::cudaDataType_t::CUDA_R_16BF,
|
|
cfg.gemm.ldb,
|
|
cfg.stride_b,
|
|
(&beta) as *const bf16 as *const _,
|
|
*c.device_ptr_mut() as *mut _,
|
|
sys::cudaDataType_t::CUDA_R_16BF,
|
|
cfg.gemm.ldc,
|
|
cfg.stride_c,
|
|
cfg.batch_size,
|
|
compute_type,
|
|
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
|
)
|
|
}
|