Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op
* small fix
* add as a method on `Tensor`
* implement gradient calculation for sigmoid
* add sigmoid tests
* we should have a specialized op for this
* fix clippy
* fix clippy 2
* Revert all previous commits in favor of a `CustomOp` based solution
* use `CustomOp1` implementation
* fix rustfmt
* experimental add metal impl
* add cuda kernel impl
* fix fmt
* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
parent ed7b99f525
commit 3bbb88fcb4
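Taken together, this change replaces the tensor-expression sigmoid with a dedicated `CustomOp1` that has CPU, CUDA, and Metal forward kernels plus an analytic backward pass. Below is a minimal sketch of how the op is exercised end to end; it assumes the in-repo crate alias `candle` for candle-core (as used throughout this diff) and candle's usual `Var`/`backward` autograd API. The input values and tolerance are illustrative, not taken from this PR.

    use candle::{Device, Result, Var};

    fn check_sigmoid_grad() -> Result<()> {
        let dev = Device::Cpu;
        let x = Var::new(&[-2f32, -0.5, 0.0, 1.5, 3.0], &dev)?;
        // Forward pass goes through the new CustomOp1-based sigmoid.
        let s = candle_nn::ops::sigmoid(x.as_tensor())?;

        // Backward pass uses Sigmoid::bwd, i.e. grad_x = grad_s * s * (1 - s).
        let grads = s.sum_all()?.backward()?;
        let dx = grads.get(&x).expect("gradient for x");

        // Reference gradient computed from the forward output.
        let expected = s.ones_like()?.sub(&s)?.mul(&s)?;
        let diff = dx.sub(&expected)?.abs()?.sum_all()?.to_vec0::<f32>()?;
        assert!(diff < 1e-6);
        Ok(())
    }

On CPU this exercises `cpu_fwd`; with the `cuda` or `metal` features enabled, the same call dispatches to the new `usigmoid_*` CUDA kernels or the Metal `sigmoid` kernels added in the hunks below.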
@@ -18,7 +18,7 @@ pub use device::{CudaDevice, DeviceId};
 pub use error::{CudaError, WrapErr};
 pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};

-enum SlicePtrOrNull<T> {
+pub enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
     Null,
 }
@@ -33,7 +33,7 @@ unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
 }

 impl SlicePtrOrNull<usize> {
-    fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
+    pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
         let ds = if l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {
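(`SlicePtrOrNull` and `params_from_layout` are made `pub` so that the new `CustomOp1`-based sigmoid in candle-nn can reuse them via `candle::cuda_backend` when launching the `usigmoid` kernel; see the `cuda_fwd` hunk further down. This appears to be the "reduce some cuda duplication" item from the commit message.)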
@@ -60,6 +60,11 @@ __device__ __forceinline__ T silu_fwd(T x) {
     return x / (static_cast<T>(1) + expg(-x));
 }

+template<typename T>
+__device__ __forceinline__ T sigmoid_fwd(T x) {
+    return recipg(static_cast<T>(1) + expg(-x));
+}
+
 #define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
@@ -116,6 +121,7 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
 UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
 UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
 UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
+UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
 #endif

 #if __CUDA_ARCH__ >= 530
@@ -142,6 +148,7 @@ UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
 UNARY_OP(__half, usilu_f16, silu_fwd(x))
 UNARY_OP1(__half, upowf_f16, powg(x, param))
 UNARY_OP(__half, usign_f16, sign_(x))
+UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
 #endif

 UNARY_OP(uint8_t, ucopy_u8, x)
@@ -193,3 +200,5 @@ UNARY_OP1(float, upowf_f32, powg(x, param))
 UNARY_OP1(double, upowf_f64, powg(x, param))
 UNARY_OP(float, usign_f32, sign_(x))
 UNARY_OP(double, usign_f64, sign_(x))
+UNARY_OP(float, usigmoid_f32, sigmoid_fwd(x))
+UNARY_OP(double, usigmoid_f64, sigmoid_fwd(x))
@@ -129,7 +129,7 @@ macro_rules! ops{
 pub mod unary {
     ops!(
         cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
-        tanh, recip, silu, sign
+        tanh, recip, silu, sign, sigmoid
     );
 }
 pub mod binary {
@@ -67,6 +67,9 @@ template <typename T> METAL_FUNC T relu(T in){
 template <typename T> METAL_FUNC T silu(T in){
     return in / (static_cast<T>(1) + exp(-in));
 }
+template <typename T> METAL_FUNC T sigmoid(T in) {
+    return recip(static_cast<T>(1) + exp(-in));
+}

 #define TILE_SIZE 2

@@ -155,6 +158,7 @@ UNARY_OP(tanh)
 UNARY_OP(recip)
 UNARY_OP(relu)
 UNARY_OP(sign)
+UNARY_OP(sigmoid)
 UNARY(id, float, copy_f32, copy_f32_strided)
 UNARY(id, half, copy_f16, copy_f16_strided)
 UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -185,6 +189,7 @@ BFLOAT_UNARY_OP(tanh)
 BFLOAT_UNARY_OP(recip)
 BFLOAT_UNARY_OP(relu)
 BFLOAT_UNARY_OP(sign)
+BFLOAT_UNARY_OP(sigmoid)

 UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
@@ -43,9 +43,193 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
     &xs[0].silu()? * &xs[1]
 }

+struct Sigmoid;
+
+impl candle::CustomOp1 for Sigmoid {
+    fn name(&self) -> &'static str {
+        "sigmoid"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        fn fwd<T: num_traits::Float>(v: T) -> T {
+            (v.neg().exp() + T::one()).recip()
+        }
+
+        // FIXME: using `candle::map_dtype` causes compilation errors.
+        let storage = match storage {
+            CpuStorage::BF16(slice) => {
+                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            CpuStorage::F16(slice) => {
+                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            CpuStorage::F32(slice) => {
+                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            CpuStorage::F64(slice) => {
+                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
+            }
+            _ => Err(candle::Error::UnsupportedDTypeForOp(
+                storage.dtype(),
+                self.name(),
+            ))?,
+        };
+        Ok((storage, layout.shape().clone()))
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        storage: &candle::CudaStorage,
+        layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+        };
+        use candle::cuda_backend::SlicePtrOrNull;
+        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S;
+        impl Map1 for S {
+            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+                &self,
+                src: &CudaSlice<T>,
+                dev: &CudaDevice,
+                layout: &Layout,
+            ) -> Result<CudaSlice<T>> {
+                let shape = layout.shape();
+                let dims = shape.dims();
+                let el_count = shape.elem_count();
+                let cfg = LaunchConfig::for_num_elems(el_count as u32);
+                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
+                let src = &src.slice(layout.start_offset()..);
+                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
+                // SAFETY: Set later by running the kernel.
+                let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
+
+                let params = (el_count, dims.len(), &ds, src, &out);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(out)
+            }
+        }
+
+        let dev = storage.device();
+        let slice = S.map(&storage.slice, dev, layout)?;
+        let dst = candle::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, layout.shape().clone()))
+    }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        use candle::MetalError;
+        let device = storage.device();
+        let dtype = storage.dtype();
+        let shape = layout.shape();
+        let el_count = shape.elem_count();
+        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
+        let command_buffer = device.command_buffer()?;
+        command_buffer.set_label("sigmoid");
+        let src = candle_metal_kernels::BufferOffset {
+            buffer: storage.buffer(),
+            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
+        };
+
+        match (el_count % 2, dtype, layout.is_contiguous()) {
+            (0, DType::BF16 | DType::F16, true) => {
+                use candle_metal_kernels::unary::contiguous_tiled;
+                let kernel_name = match dtype {
+                    DType::F16 => contiguous_tiled::sigmoid::HALF,
+                    DType::F32 => contiguous_tiled::sigmoid::FLOAT,
+                    DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
+                    dtype => {
+                        candle::bail!(
+                            "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
+                        )
+                    }
+                };
+                candle_metal_kernels::call_unary_contiguous_tiled(
+                    device.metal_device(),
+                    &command_buffer,
+                    device.kernels(),
+                    kernel_name,
+                    el_count,
+                    src,
+                    &buffer,
+                )
+                .map_err(MetalError::from)?;
+            }
+            (_, _, true) => {
+                use candle_metal_kernels::unary::contiguous;
+                let kernel_name = match dtype {
+                    DType::F16 => contiguous::sigmoid::HALF,
+                    DType::F32 => contiguous::sigmoid::FLOAT,
+                    DType::BF16 => contiguous::sigmoid::BFLOAT,
+                    dtype => {
+                        candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
+                    }
+                };
+                candle_metal_kernels::call_unary_contiguous(
+                    device.metal_device(),
+                    &command_buffer,
+                    device.kernels(),
+                    kernel_name,
+                    el_count,
+                    src,
+                    &buffer,
+                )
+                .map_err(MetalError::from)?;
+            }
+            (_, _, false) => {
+                use candle_metal_kernels::unary::strided;
+                let kernel_name = match dtype {
+                    DType::F16 => strided::sigmoid::HALF,
+                    DType::F32 => strided::sigmoid::FLOAT,
+                    DType::BF16 => strided::sigmoid::BFLOAT,
+                    dtype => {
+                        candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
+                    }
+                };
+                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
+                candle_metal_kernels::call_unary_strided(
+                    device.metal_device(),
+                    &command_buffer,
+                    device.kernels(),
+                    kernel_name,
+                    layout.dims(),
+                    src,
+                    layout.stride(),
+                    dst,
+                )
+                .map_err(MetalError::from)?;
+            }
+        }
+
+        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
+        Ok((new_storage, layout.shape().clone()))
+    }
+
+    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
+        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
+        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
+        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
+    }
+}
+
 pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
-    // TODO: Should we have a specialized op for this?
-    (xs.neg()?.exp()? + 1.0)?.recip()
+    xs.apply_op1(Sigmoid)
 }

 pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
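The same `CustomOp1` pattern generalizes to other elementwise ops: implement `cpu_fwd` (and optionally `cuda_fwd`/`metal_fwd`) for the forward pass, `bwd` for the gradient, and expose the op through `Tensor::apply_op1`. Below is a minimal CPU-only sketch for a hypothetical `Square` op, assuming the same crate context as the diff above (candle-core aliased as `candle`, `num_traits` available); it illustrates the trait surface and is not code from this PR.

    use candle::{CpuStorage, Layout, Result, Shape, Tensor};

    // Hypothetical elementwise op: x -> x^2, CPU forward + analytic backward only.
    struct Square;

    impl candle::CustomOp1 for Square {
        fn name(&self) -> &'static str {
            "square"
        }

        fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
            use candle::backend::BackendStorage;
            fn fwd<T: num_traits::Float>(v: T) -> T {
                v * v
            }
            let storage = match storage {
                CpuStorage::F32(s) => CpuStorage::F32(candle::cpu_backend::unary_map(s, layout, fwd)),
                CpuStorage::F64(s) => CpuStorage::F64(candle::cpu_backend::unary_map(s, layout, fwd)),
                _ => Err(candle::Error::UnsupportedDTypeForOp(storage.dtype(), self.name()))?,
            };
            Ok((storage, layout.shape().clone()))
        }

        fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
            // d/dx x^2 = 2x, so grad_x = grad_res * 2x.
            Ok(Some(grad_res.mul(&arg.affine(2.0, 0.0)?)?))
        }
    }

    pub fn square(xs: &Tensor) -> Result<Tensor> {
        xs.apply_op1(Square)
    }

Note how `Sigmoid::bwd` above expresses its gradient in terms of the forward output `res` rather than the input, which avoids recomputing the sigmoid during the backward pass.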
@@ -170,8 +170,19 @@ fn rope_thd(device: &Device) -> Result<()> {
     Ok(())
 }

+fn sigmoid(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?;
+    let s1 = candle_nn::ops::sigmoid(&tensor)?;
+    let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
+    let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff, 0.);
+    Ok(())
+}
+
 test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
 test_device!(rope, rope_cpu, rope_gpu, rope_metal);
 test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
 test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
+test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);