Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 135 additions & 188 deletions custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1061,207 +1061,154 @@ void MultiQueryAppendAttention(
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size);

const int num_chunks = div_up(max_seq_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps);

if (num_chunks <= 1) {
auto nosplit_kv_kernel =
multi_query_append_attention_warp1_4_kernel<NV_TYPE,
false,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}

nosplit_kv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
quant_min_bound,
in_scale,
chunk_size,
nullptr,
nullptr,
nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
if (ENABLE_PREFILL) {
tmp_workspace =
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(token_num * num_chunks *
num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM));
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
} else {
if (ENABLE_PREFILL) {
tmp_workspace =
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(token_num * num_chunks *
num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
}
split_kv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
quant_min_bound,
in_scale,
chunk_size,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);

// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
}
split_kv_kernel<<<grids, blocks, smem_size, stream>>>(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
quant_min_bound,
in_scale,
chunk_size,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
speculate_max_draft_token_num);

// merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
}

Expand Down
Loading
Loading