Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d587f21
Support prefill in Cudagraph
littledgg Aug 18, 2025
0e5e305
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
littledgg Aug 18, 2025
63f1256
Refactor GetBlockShapeAndSplitKVBlock Kernel V2
littledgg Aug 19, 2025
109d4f2
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.1
littledgg Aug 20, 2025
7ba1c17
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.2
littledgg Aug 20, 2025
1e60bfd
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.3
littledgg Aug 20, 2025
fd41ba3
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.4
littledgg Aug 20, 2025
5d52620
Refactor GetBlockShapeAndSplitKVBlock Kernel V2.5
littledgg Aug 20, 2025
39dd695
Solve problem about encoder_num_blocks_x_cpu
littledgg Aug 22, 2025
4bc60cc
Add early-exit mechanism for attention kernel
littledgg Aug 25, 2025
d267feb
fix test case about append-attention
littledgg Aug 25, 2025
2a40480
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
littledgg Aug 25, 2025
edbbbf0
Update testcode, Add annotations to related tensors
littledgg Aug 26, 2025
86d1a94
solve conflict
littledgg Aug 26, 2025
cbca27d
move get_input_length_list
littledgg Aug 26, 2025
3f40024
solve conflict
littledgg Aug 26, 2025
62a3f05
solve test_code
littledgg Aug 27, 2025
0f92c63
Add annotations about early-exit for attention kernel
littledgg Aug 27, 2025
f57a1ca
Add annotations about early-exit for attention kernel2
littledgg Aug 27, 2025
684f6c9
solve comment
littledgg Aug 27, 2025
edb923c
solve conflict
littledgg Sep 1, 2025
960c6c9
Merge branch 'develop' into prefill_in_cudagraph
gongshaotian Sep 2, 2025
5df6078
solve conflict and test case
littledgg Sep 4, 2025
c7b3d5b
solve conflict
littledgg Sep 4, 2025
8f8011d
solve mtp
littledgg Sep 5, 2025
e49ff4a
Merge branch 'develop' into prefill_in_cudagraph
littledgg Sep 5, 2025
d15e4fd
Merge branch 'develop' into prefill_in_cudagraph
littledgg Sep 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要跑一下deepseek的模型,确认下 mla attention 能跑prefill 捕获吗

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下一个PR再支持吧

Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ __global__ void multi_query_append_attention_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand All @@ -74,6 +75,11 @@ __global__ void multi_query_append_attention_kernel(

block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -422,6 +428,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand All @@ -445,6 +452,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -902,6 +914,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -960,6 +973,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down Expand Up @@ -1134,6 +1148,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -1206,6 +1221,7 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down
16 changes: 16 additions & 0 deletions custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand Down Expand Up @@ -85,6 +86,11 @@ __global__ void multi_query_append_attention_c4_kernel(

block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -520,6 +526,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand Down Expand Up @@ -549,6 +556,11 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -1107,6 +1119,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -1171,6 +1184,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down Expand Up @@ -1365,6 +1379,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -1445,6 +1460,7 @@ void MultiQueryAppendC4Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down
16 changes: 16 additions & 0 deletions custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand Down Expand Up @@ -86,6 +87,11 @@ __global__ void multi_query_append_attention_c8_kernel(

block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -489,6 +495,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand Down Expand Up @@ -518,6 +525,11 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;

//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
return;
}

const uint32_t q_len = seq_lens[batch_id];
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -1075,6 +1087,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -1133,6 +1146,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down Expand Up @@ -1350,6 +1364,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -1424,6 +1439,7 @@ void MultiQueryAppendC8Attention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,21 @@ get_max_len_kv_ernel(int *max_seq_lens_out, const int *seq_lens_this_time,
}
}

std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &max_len_kv_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
Expand All @@ -223,31 +230,22 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
int max_system_len = max_len_cpu_ptr[6];
int max_just_dec_len_without_system = max_len_cpu_ptr[7];

paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/


auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
get_max_len_kv_ernel<128><<<1, 128, 0, stream>>>(
max_len_kv.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(), bsz);

max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);

max_len_kv_cpu.copy_(max_len_kv, max_len_kv_cpu.place(), false);

if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
kv_batch_ids =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({bsz * max_tile_size_per_bs_kv}, paddle::DataType::INT32,
seq_lens_encoder.place());
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());

Expand All @@ -258,38 +256,20 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);

kv_num_blocks_x_cpu = kv_num_blocks_x.copy_to(paddle::CPUPlace(), false);

const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
encoder_batch_ids =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({bsz * encoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu =
encoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
encoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
encoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
kv_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}

if (max_just_dec_len_this_time > 0) {
Expand All @@ -314,15 +294,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}

return {
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
Expand All @@ -333,16 +304,17 @@ PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu"
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"max_len_kv_cpu"
})
.Outputs({
paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks_x_cpu"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
"max_len_kv_cpu"

})
.Attrs({
"encoder_block_shape_q: int",
Expand Down
9 changes: 8 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,21 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank, const int device_id,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor &kv_signal_metadata,
const int layer_id);

std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_kv_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
Expand Down
Loading
Loading