Add a dedicated CUDA kernel for softmax. (#746)

Laurent Mazare 2023-09-05 17:53:20 +02:00 committed by GitHub
parent 6615daf242
commit 94c6a8d3d3
1 changed file with 55 additions and 0 deletions


@@ -49,6 +49,50 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
// Softmax implementation adapted from ggml.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L4159
template <typename T, typename ACC>
__device__ void softmax(const T * x, T * dst, const int ncols) {
    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int block_size = blockDim.y;
    const int tid = threadIdx.y;

    // Each thread scans a strided subset of the row's columns for its local max.
    T max_val = -INFINITY;
    for (int col = tid; col < ncols; col += block_size) {
        const int i = row*ncols + col;
        max_val = maxg(max_val, x[i]);
    }

    // find the max value in the block
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        max_val = maxg(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
    }

    // Exponentiate relative to the row max and accumulate a partial sum.
    ACC tmp = 0.;
    for (int col = tid; col < ncols; col += block_size) {
        const int i = row*ncols + col;
        const T val = expg(x[i] - max_val);
        tmp += static_cast<ACC>(val);
        dst[i] = val;
    }

    // sum up partial sums
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    // Normalize by the row sum.
    const ACC inv_tmp = 1. / tmp;
    for (int col = tid; col < ncols; col += block_size) {
        const int i = row*ncols + col;
        dst[i] *= inv_tmp;
    }
}
template <typename T>
__device__ void
fast_max(const size_t src_numel, const size_t el_to_sum_per_block,
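
As a usage note (not part of this diff): the indexing above expects one row per block along x and a single 32-thread warp along y, so the warp-wide __shfl_xor_sync reductions cover every thread that touched the row. A hypothetical host-side launch for the softmax_f32 kernel generated further below might look like this; launch_softmax_f32 is an illustrative helper, not something this commit adds:

// Hypothetical host-side launch: one block per row in x, a single warp of 32
// threads in y so the shuffle reductions above span the whole block.
void launch_softmax_f32(const float *src, float *dst, int nrows, int ncols,
                        cudaStream_t stream) {
    const dim3 grid_dim(nrows, 1, 1);
    const dim3 block_dim(1, 32, 1);
    softmax_f32<<<grid_dim, block_dim, 0, stream>>>(src, dst, ncols);
}
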
@@ -290,12 +334,21 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
} \
}
#define SOFTMAX_OP(TYPENAME, ACC_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const TYPENAME *src, TYPENAME *dst, \
    const int n_cols) { \
    softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
} \

#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
#endif
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
#endif
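
For reference, the SOFTMAX_OP invocations above are thin wrappers; SOFTMAX_OP(__half, float, softmax_f16), for example, expands to roughly the following, with the reduction accumulator widened to float for the half and bfloat16 inputs:

// Approximate expansion of SOFTMAX_OP(__half, float, softmax_f16):
extern "C" __global__ void softmax_f16(
    const __half *src, __half *dst,
    const int n_cols) {
    softmax<__half, float>(src, dst, n_cols);
}
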
@@ -303,6 +356,8 @@ FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fa
SUM_OP(float, sum_f32)
SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
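
A hypothetical smoke test (not part of this commit, and assuming it is compiled into the same .cu file so softmax_f32 is visible) could launch the kernel on a single row and check that the outputs sum to one:

// Hypothetical smoke test: run softmax_f32 on one 4-column row and verify
// that the resulting probabilities sum to ~1.
#include <cstdio>
#include <cmath>

int main() {
    const int nrows = 1, ncols = 4;
    const float h_src[ncols] = {1.f, 2.f, 3.f, 4.f};
    float h_dst[ncols] = {0.f};

    float *d_src = nullptr, *d_dst = nullptr;
    cudaMalloc(&d_src, ncols * sizeof(float));
    cudaMalloc(&d_dst, ncols * sizeof(float));
    cudaMemcpy(d_src, h_src, ncols * sizeof(float), cudaMemcpyHostToDevice);

    // One block per row, one warp in y (matches the in-kernel indexing above).
    softmax_f32<<<dim3(nrows, 1, 1), dim3(1, 32, 1)>>>(d_src, d_dst, ncols);
    cudaMemcpy(h_dst, d_dst, ncols * sizeof(float), cudaMemcpyDeviceToHost);

    float sum = 0.f;
    for (int i = 0; i < ncols; ++i) sum += h_dst[i];
    printf("row sum = %f (expected ~1.0)\n", sum);

    cudaFree(d_src);
    cudaFree(d_dst);
    return std::fabs(sum - 1.f) < 1e-5f ? 0 : 1;
}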