Softmax implementation for cuda. (#747)
parent 94c6a8d3d3 · commit a0d65585db
candle-core/src/cuda_backend.rs

@@ -1,7 +1,8 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
-use candle_kernels as kernels;
+pub use candle_kernels as kernels;
+pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
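Note: switching these imports to `pub use` re-exports candle-kernels and cudarc through candle itself, so a downstream crate can drive raw CUDA kernels without taking its own dependency on either crate (and without risking two mismatched cudarc versions in one build). The candle-nn hunk below relies on exactly these paths, e.g.:

    use candle::cuda_backend::cudarc::driver::LaunchConfig;
    use candle::cuda_backend::kernels;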
@@ -383,7 +383,7 @@ impl BackendDevice for CudaDevice {
 }
 
 #[derive(Debug)]
-enum CudaStorageSlice {
+pub enum CudaStorageSlice {
     U8(CudaSlice<u8>),
     U32(CudaSlice<u32>),
     I64(CudaSlice<i64>),
@@ -394,7 +394,7 @@ enum CudaStorageSlice {
 }
 type S = CudaStorageSlice;
 
-trait Map1 {
+pub trait Map1 {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src: &CudaSlice<T>,
@@ -416,7 +416,7 @@ trait Map1 {
     }
 }
 
-trait Map2 {
+pub trait Map2 {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src1: &CudaSlice<T>,
@@ -441,7 +441,7 @@ trait Map2 {
     }
 }
 
-trait Map2InPlace {
+pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
@@ -472,7 +472,7 @@ trait Map2InPlace {
     }
 }
 
-trait Map1Any {
+pub trait Map1Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
         &self,
         src: &CudaSlice<T>,
@@ -495,7 +495,7 @@ trait Map1Any {
     }
 }
 
-trait Map2Any {
+pub trait Map2Any {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         src1: &CudaSlice<T>,
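Note: the Map1/Map2/Map2InPlace/Map1Any/Map2Any family all follow one pattern: an implementor writes a single dtype-generic `f` that launches a kernel, and a provided `map` method (not touched by this diff) unwraps CudaStorageSlice, monomorphizes `f` for the matching dtype, and re-wraps the result. A rough sketch of that dispatch, assuming the variants shown above plus candle's remaining dtypes (BF16, F16, F32, F64); the real body lives outside the hunks shown here:

    // Hypothetical sketch of Map1's provided dispatch method.
    fn map(&self, s: &S, dev: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, dev, l)?),
            S::U32(s) => S::U32(self.f(s, dev, l)?),
            S::I64(s) => S::I64(self.f(s, dev, l)?),
            S::BF16(s) => S::BF16(self.f(s, dev, l)?),
            S::F16(s) => S::F16(self.f(s, dev, l)?),
            S::F32(s) => S::F32(self.f(s, dev, l)?),
            S::F64(s) => S::F64(self.f(s, dev, l)?),
        };
        Ok(out)
    }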
@@ -532,7 +532,7 @@ impl Map1 for Clone {
     }
 }
 
-fn kernel_name<T: WithDType>(root: &str) -> String {
+pub fn kernel_name<T: WithDType>(root: &str) -> String {
     let dtype = T::DTYPE.as_str();
     format!("{root}_{dtype}")
 }
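Note: the CUDA kernels are compiled once per dtype, and `kernel_name` selects the specialized symbol by suffixing the dtype's short name. Assuming `DType::as_str` yields names like "f32" and "f16":

    // Picks the dtype-specialized symbol compiled into the PTX blob.
    assert_eq!(kernel_name::<f32>("softmax"), "softmax_f32");
    assert_eq!(kernel_name::<half::f16>("softmax"), "softmax_f16");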
@@ -1310,8 +1310,8 @@ fn slice_src_and_dst<'a, T>(
 
 #[derive(Debug)]
 pub struct CudaStorage {
-    slice: CudaStorageSlice,
-    device: CudaDevice,
+    pub slice: CudaStorageSlice,
+    pub device: CudaDevice,
 }
 
 pub trait CudaDType: Sized {
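Note: making `slice` and `device` public is the other half of opening the backend up: an external custom op needs to read the input's raw CUDA buffer (`&storage.slice` in the hunk below) and to assemble a fresh CudaStorage for its own output.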
candle-nn/src/ops.rs

@@ -126,19 +126,61 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         }
     }
 
     #[cfg(feature = "cuda")]
     fn cuda_fwd(
         &self,
-        _storage: &candle::CudaStorage,
-        _layout: &Layout,
+        storage: &candle::CudaStorage,
+        layout: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
-        candle::bail!("TODO: implement a cuda kernel")
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S;
+        impl Map1 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                src: &CudaSlice<T>,
+                dev: &CudaDevice,
+                layout: &Layout,
+            ) -> Result<CudaSlice<T>> {
+                // contiguous_offsets() already includes the layout's start
+                // offset, so this slice covers exactly the input elements.
+                let src = match layout.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let el = layout.shape().elem_count();
+                let dims = layout.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+                let cfg = LaunchConfig {
+                    grid_dim: (n_rows as u32, 1, 1),
+                    block_dim: (1, 32, 1),
+                    shared_mem_bytes: 0,
+                };
+                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (&src, &dst, n_cols as i32);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(dst)
+            }
+        }
+
+        use candle::backend::BackendStorage;
+        let dev = storage.device();
+        let slice = S.map(&storage.slice, dev, layout)?;
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, layout.shape().clone()))
     }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
-    if xs.device().is_cpu() {
-        xs.apply_op1_no_bwd(&SoftmaxLastDim)
-    } else {
-        softmax(xs, candle::D::Minus1)
-    }
+    xs.apply_op1_no_bwd(&SoftmaxLastDim)
 }
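Note on the launch geometry: the grid has one block per row and each block is a single 32-thread warp (`block_dim: (1, 32, 1)`), so each softmax row is reduced by one warp, with the row's `n_cols` elements strided across those 32 threads. With the CUDA path in place, `softmax_last_dim` now always dispatches to the custom op, and the op picks the backend via `cpu_fwd`/`cuda_fwd`. A minimal usage sketch, assuming a build with the `cuda` feature and a visible device 0:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let dev = Device::new_cuda(0)?;
        let xs = Tensor::rand(0f32, 1f32, (8, 1024), &dev)?;
        // Dispatches to the new cuda_fwd; each row of ys sums to 1.
        let ys = candle_nn::ops::softmax_last_dim(&xs)?;
        assert_eq!(ys.dims(), &[8, 1024]);
        Ok(())
    }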