125 lines
3.4 KiB
Plaintext
125 lines
3.4 KiB
Plaintext
#include "kernels.h"
|
|
#include "kernel_helpers.h"
|
|
#include "flash_fwd_launch_template.h"
|
|
|
|
// Selects and invokes the run_mha_fwd_<elem_type, kHeadDim, Is_causal>
// instantiation that matches the runtime fields of `params`.
//
// The three nested switch macros (defined outside this file) each receive a
// runtime value and a lambda:
//   - FP16_SWITCH    is driven by !params.is_bf16
//   - HEADDIM_SWITCH is driven by params.d
//   - BOOL_SWITCH    is driven by params.is_causal and names the resulting
//                    compile-time constant `Is_causal`
// `elem_type` and `kHeadDim` are not declared here; presumably they are
// introduced by FP16_SWITCH and HEADDIM_SWITCH respectively — confirm against
// the macro definitions.
//
// params: kernel parameters, taken by reference and forwarded unchanged.
// stream: CUDA stream forwarded to run_mha_fwd_.
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
            });
        });
    });
}
|
|
|
|
// C-linkage entry point: packs the flat argument list into a zero-initialised
// Flash_fwd_params and hands it to run_mha_fwd on the default stream.
//
// Pointers (stored as-is; this function never dereferences them —
// presumably they are device pointers, confirm with the caller):
//   q_ptr, k_ptr, v_ptr, o_ptr   -> params.{q,k,v,o}_ptr
//   softmax_lse_ptr              -> params.softmax_lse_ptr
//   alibi_slopes_ptr             -> params.alibi_slopes_ptr
//   cu_seqlens_q_ptr, cu_seqlens_k_ptr -> params.cu_seqlens_{q,k}
//
// Strides (per the original comment, expressed in elements, not bytes):
//   *_batch_stride, *_row_stride, *_head_stride for q/k/v/o, plus
//   alibi_slopes_batch_stride.
//
// Dimensions:
//   b, h, h_k, d, d_rounded, seqlen_q, seqlen_k, seqlen_q_rounded,
//   seqlen_k_rounded are copied verbatim. h_k must be non-zero because
//   params.h_h_k_ratio is computed as the integer quotient h / h_k.
//
// Other:
//   softmax_scale      -> params.scale_softmax; also pre-multiplied by
//                         M_LOG2E into params.scale_softmax_log2.
//   is_bf16, is_causal -> copied to the same-named params fields; run_mha_fwd
//                         dispatches on them.
//   window_size_left / window_size_right -> copied verbatim, no normalisation
//                         is performed here.
//
// Dropout is fixed to "keep everything" (p_dropout = 1), p_ptr and seqused_k
// are null, is_seqlens_k_cumulative is true and num_splits is 1.
extern "C" void run_mha(
    void *q_ptr,
    void *k_ptr,
    void *v_ptr,
    void *o_ptr,
    void *softmax_lse_ptr,
    void *alibi_slopes_ptr,

    int32_t *cu_seqlens_q_ptr,
    int32_t *cu_seqlens_k_ptr,

    uint32_t q_batch_stride,
    uint32_t k_batch_stride,
    uint32_t v_batch_stride,
    uint32_t o_batch_stride,
    uint32_t alibi_slopes_batch_stride,

    uint32_t q_row_stride,
    uint32_t k_row_stride,
    uint32_t v_row_stride,
    uint32_t o_row_stride,

    uint32_t q_head_stride,
    uint32_t k_head_stride,
    uint32_t v_head_stride,
    uint32_t o_head_stride,

    uint32_t b,
    uint32_t h,
    uint32_t h_k,
    uint32_t d,
    uint32_t d_rounded,
    float softmax_scale,

    uint32_t seqlen_q,
    uint32_t seqlen_k,
    uint32_t seqlen_q_rounded,
    uint32_t seqlen_k_rounded,

    int is_bf16,
    int is_causal,

    int window_size_left,
    int window_size_right
) {
    Flash_fwd_params params;
    // Reset the parameters: every field not assigned below stays zero.
    // NOTE(review): byte-wise zeroing assumes Flash_fwd_params is trivially
    // copyable — confirm against its declaration.
    memset(&params, 0, sizeof(params));

    // Set the pointers and strides.
    params.q_ptr = q_ptr;
    params.k_ptr = k_ptr;
    params.v_ptr = v_ptr;
    params.o_ptr = o_ptr;

    params.softmax_lse_ptr = softmax_lse_ptr;
    params.alibi_slopes_ptr = alibi_slopes_ptr;

    // All strides are in elements, not bytes.
    params.q_batch_stride = q_batch_stride;
    params.k_batch_stride = k_batch_stride;
    params.v_batch_stride = v_batch_stride;
    params.o_batch_stride = o_batch_stride;
    params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;

    params.q_row_stride = q_row_stride;
    params.k_row_stride = k_row_stride;
    params.v_row_stride = v_row_stride;
    params.o_row_stride = o_row_stride;
    params.q_head_stride = q_head_stride;
    params.k_head_stride = k_head_stride;
    params.v_head_stride = v_head_stride;
    params.o_head_stride = o_head_stride;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    // Unsigned integer division: truncates, and requires h_k != 0.
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    // Same scale multiplied by log2(e), stored alongside the plain one.
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;

    // Dropout is disabled: the keep probability is hard-coded to 1, and the
    // derived values below follow from it.
    params.p_dropout = 1.; // probability to keep
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    params.is_bf16 = is_bf16;
    params.cu_seqlens_q = cu_seqlens_q_ptr;
    params.cu_seqlens_k = cu_seqlens_k_ptr;
    params.p_ptr = nullptr; // used for `return_softmax`.
    params.seqused_k = nullptr;

    params.is_causal = is_causal;
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;

    params.is_seqlens_k_cumulative = true;
    params.num_splits = 1;

    cudaStream_t stream = 0; // Use the default stream.
    run_mha_fwd(params, stream);
}
|