Add a naive conv2d cuda kernel. (#438)
* Add a naive conv2d cuda kernel. * Proper conv2d support on the rust side. * Conv1d testing on gpu. * Also use the test on gpus. * Fix the clean-ptx target.
This commit is contained in:
parent
eab54e4490
commit
34f4b3187e
2
Makefile
2
Makefile
|
@ -2,6 +2,8 @@ clean-ptx:
|
|||
find target -name "*.ptx" -type f -delete
|
||||
echo "" > candle-kernels/src/lib.rs
|
||||
touch candle-kernels/build.rs
|
||||
touch candle-examples/build.rs
|
||||
touch candle-flash-attn/build.rs
|
||||
|
||||
clean:
|
||||
cargo clean
|
||||
|
|
|
@ -897,7 +897,6 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||
// Kernel shape: (c_out, c_in_k, k_size)
|
||||
// Input shape: (b_size, c_in, l_in) or (c_in, l_in)
|
||||
let p = &self.0;
|
||||
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
|
@ -917,7 +916,44 @@ impl<'a> Map2 for Conv1D<'a> {
|
|||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (el, l_out, p.stride, &ds, inp, k, &out);
|
||||
let params = (el, l_out, p.stride, p.padding, &ds, inp, k, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
inp: &CudaSlice<T>,
|
||||
inp_l: &Layout,
|
||||
k: &CudaSlice<T>,
|
||||
k_l: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
// Kernel shape: (c_out, c_in_k, w_k, h_k)
|
||||
// Input shape: (b_size, c_in, w_in, c_in)
|
||||
let p = &self.0;
|
||||
let inp = &inp.slice(inp_l.start_offset()..);
|
||||
let k = &k.slice(k_l.start_offset()..);
|
||||
let shape = inp_l.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let (out_w, out_h) = (p.out_w(), p.out_h());
|
||||
let dst_el = p.c_out * out_w * out_h * p.b_size;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let ds = if dims.len() == 4 {
|
||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||
} else {
|
||||
panic!("unexpected input shape for conv1d {dims:?}")
|
||||
};
|
||||
let ds = dev.htod_copy(ds).w()?;
|
||||
let params = (el, out_w, out_h, p.stride, p.padding, &ds, inp, k, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
|
@ -1383,12 +1419,14 @@ impl BackendStorage for CudaStorage {
|
|||
|
||||
fn conv2d(
|
||||
&self,
|
||||
_l: &Layout,
|
||||
_kernel: &Self,
|
||||
_kernel_l: &Layout,
|
||||
_params: &crate::conv::ParamsConv2D,
|
||||
l: &Layout,
|
||||
kernel: &Self,
|
||||
kernel_l: &Layout,
|
||||
params: &crate::conv::ParamsConv2D,
|
||||
) -> Result<Self> {
|
||||
todo!()
|
||||
let device = self.device().clone();
|
||||
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
|
|
|
@ -15,9 +15,7 @@ print(res.flatten())
|
|||
res = torch.nn.functional.conv1d(t, w, padding=1)
|
||||
print(res.flatten())
|
||||
*/
|
||||
#[test]
|
||||
fn conv1d() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
||||
|
@ -51,9 +49,7 @@ fn conv1d() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv1d_small() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
fn conv1d_small(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?;
|
||||
let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?;
|
||||
let res = t.conv1d(&w, 0, 1)?;
|
||||
|
@ -82,9 +78,7 @@ print(w.flatten())
|
|||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
*/
|
||||
#[test]
|
||||
fn conv2d() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
fn conv2d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
|
@ -138,9 +132,7 @@ print(w.flatten())
|
|||
res = torch.nn.functional.conv2d(t, w)
|
||||
print(res.flatten())
|
||||
*/
|
||||
#[test]
|
||||
fn conv2d_small() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
|
||||
|
@ -160,9 +152,7 @@ fn conv2d_small() -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_smaller() -> Result<()> {
|
||||
let dev = &Device::Cpu;
|
||||
fn conv2d_smaller(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, 0.6843, 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866,
|
||||
|
@ -180,3 +170,9 @@ fn conv2d_smaller() -> Result<()> {
|
|||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(conv1d, conv1d_cpu, conv1d_gpu);
|
||||
test_device!(conv1d_small, conv1d_small_cpu, conv1d_small_gpu);
|
||||
test_device!(conv2d, conv2d_cpu, conv2d_gpu);
|
||||
test_device!(conv2d_small, conv2d_small_cpu, conv2d_small_gpu);
|
||||
test_device!(conv2d_smaller, conv2d_smaller_cpu, conv2d_smaller_gpu);
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
// Naive implementation of conv1d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv1d(
|
||||
const size_t src_numel,
|
||||
const size_t l_out,
|
||||
const size_t stride,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
|
@ -19,7 +21,6 @@ __device__ void conv1d(
|
|||
const size_t *k_s = info + 9;
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t k_size = k_dims[2];
|
||||
const size_t k_over_2 = k_size / 2;
|
||||
const size_t c_out = k_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t l_in = src_dims[2];
|
||||
|
@ -32,12 +33,73 @@ __device__ void conv1d(
|
|||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (size_t offset = 0; offset < k_size; ++offset) {
|
||||
const size_t src_l_plus = stride * dst_l + offset;
|
||||
if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
|
||||
const size_t src_l = src_l_plus - k_over_2;
|
||||
size_t src_l = stride * dst_l + offset;
|
||||
if (src_l < padding || src_l >= padding + l_in) {
|
||||
continue;
|
||||
}
|
||||
src_l -= padding;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
|
||||
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
// Naive implementation of conv2d.
|
||||
template <typename T, typename A>
|
||||
__device__ void conv2d(
|
||||
const size_t src_numel,
|
||||
const size_t w_out,
|
||||
const size_t h_out,
|
||||
const size_t stride,
|
||||
const size_t padding,
|
||||
const size_t *info,
|
||||
const T *src,
|
||||
const T *kernel,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (dst_i >= src_numel) {
|
||||
return;
|
||||
}
|
||||
// src: (b_size, c_in, w_in, h_in)
|
||||
// k: (c_out, c_in, w_k, h_k)
|
||||
const size_t *src_dims = info;
|
||||
const size_t *src_s = info + 4;
|
||||
const size_t *k_dims = info + 8;
|
||||
const size_t *k_s = info + 12;
|
||||
const size_t w_k = k_dims[2];
|
||||
const size_t h_k = k_dims[3];
|
||||
const size_t c_out = k_dims[0];
|
||||
const size_t c_in = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
// TODO
|
||||
const size_t b_idx = dst_i / (w_out * h_out * c_out);
|
||||
const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
|
||||
const size_t dst_w = (dst_i / h_out) % w_out;
|
||||
const size_t dst_h = dst_i % h_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_s[0];
|
||||
A d = 0;
|
||||
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
|
||||
size_t src_w = stride * dst_w + w_offset;
|
||||
if (src_w < padding || src_w >= w_in + padding) {
|
||||
continue;
|
||||
}
|
||||
src_w -= padding;
|
||||
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
|
||||
size_t src_h = stride * dst_h + h_offset;
|
||||
if (src_h < padding || src_h >= h_in + padding) {
|
||||
continue;
|
||||
}
|
||||
src_h -= padding;
|
||||
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
|
||||
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
|
||||
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
|
||||
const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + w_offset * k_s[2] + h_offset * k_s[3];
|
||||
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
|
||||
}
|
||||
}
|
||||
|
@ -51,20 +113,38 @@ extern "C" __global__ void FN_NAME( \
|
|||
const size_t src_numel, \
|
||||
const size_t num_dims, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, info, src, kernel, dst); \
|
||||
conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t src_numel, \
|
||||
const size_t w_out, \
|
||||
const size_t h_out, \
|
||||
const size_t stride, \
|
||||
const size_t padding, \
|
||||
const size_t *info, \
|
||||
const TYPENAME *src, \
|
||||
const TYPENAME *kernel, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
|
||||
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
CONV1D_OP(__half, float, conv1d_f16)
|
||||
CONV2D_OP(__half, float, conv2d_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
|
@ -72,3 +152,8 @@ CONV1D_OP(double, double, conv1d_f64)
|
|||
CONV1D_OP(uint8_t, uint8_t, conv1d_u8)
|
||||
CONV1D_OP(uint32_t, uint32_t, conv1d_u32)
|
||||
|
||||
CONV2D_OP(float, float, conv2d_f32)
|
||||
CONV2D_OP(double, double, conv2d_f64)
|
||||
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
|
||||
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
|
||||
|
||||
|
|
Loading…
Reference in New Issue