Skip to content
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8351e83
success run ngram
gongshaotian Aug 20, 2025
02e8384
Revert "[Code Simplification] remove cum_offsets (#3410)"
lizexu123 Aug 21, 2025
1444ba6
success run ngram5 tp4 42bs
lizexu123 Aug 21, 2025
892c0c2
success run ngram5 tp4 42bs
lizexu123 Aug 22, 2025
18d9823
merge develop
gongshaotian Sep 1, 2025
64ea2f7
mtp draft commit
gongshaotian Sep 2, 2025
3263006
enable target model in cuda graph
littledgg Sep 8, 2025
5b75ade
Merge pull request #1 from littledgg/mtp
gongshaotian Sep 8, 2025
4772a4f
add decorator for target model
zeroRains Sep 9, 2025
4a0a6df
enable draft model in cudagraph v0.5
littledgg Sep 10, 2025
ec4a2df
revert revrt cum_offset
littledgg Sep 12, 2025
529214c
Merge pull request #3 from littledgg/mtp
gongshaotian Sep 12, 2025
2dd98da
enable target model in cudagraph v0.9 And clean debug code
littledgg Sep 12, 2025
1d3ef67
Revert "success run ngram"
Sep 12, 2025
349988f
add reverted code
Sep 12, 2025
15d3103
enable target model in cudagraph v0.9
littledgg Sep 12, 2025
7f11653
solve comment
littledgg Sep 12, 2025
bb9c911
Merge pull request #4 from littledgg/mtp
gongshaotian Sep 12, 2025
d1115a7
merge remote mtp
Sep 12, 2025
77e64ed
merge develop & solve conflict
Sep 15, 2025
235b0ba
fix bid < 0
Sep 16, 2025
3516be4
Enable Target Model Padding And Draft Model in cudagraph
littledgg Sep 16, 2025
c6cdc17
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
littledgg Sep 16, 2025
167fb58
solve problem
littledgg Sep 16, 2025
4c10571
Merge pull request #5 from littledgg/mtp
gongshaotian Sep 16, 2025
89c6c83
delete rebuild padding debug note
Sep 16, 2025
fdf49de
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 16, 2025
834639a
fast compile
Sep 17, 2025
4c09b0b
Add capture list for mtp
littledgg Sep 17, 2025
9f71c0e
Merge pull request #6 from littledgg/mtp
gongshaotian Sep 18, 2025
fc6ce99
success run 256 tp1 mtp
Sep 18, 2025
8c306d8
Enable Lite TP2 Bsz256
littledgg Sep 18, 2025
cf01a97
Merge pull request #7 from littledgg/mtp
gongshaotian Sep 19, 2025
e28327a
realy enable tp2 bsz 256
littledgg Sep 22, 2025
a44e2d9
fix problem
littledgg Sep 22, 2025
00de438
Merge pull request #8 from littledgg/mtp
gongshaotian Sep 23, 2025
678152f
Solve problem for Draft model in cudagraph
littledgg Sep 23, 2025
d841cc6
Solve comment
littledgg Sep 23, 2025
3bf990c
Merge pull request #9 from littledgg/mtp
gongshaotian Sep 24, 2025
2eaf778
replace emptytensor as zeros
Sep 24, 2025
24fa8cb
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 24, 2025
1a4190b
Solve comments
littledgg Sep 24, 2025
d3e7df9
Revert "fast compile"
littledgg Sep 24, 2025
96d85a0
Merge pull request #10 from littledgg/mtp
gongshaotian Sep 24, 2025
1c23a3e
merge develop
littledgg Sep 24, 2025
f814026
Merge pull request #11 from littledgg/mtp
gongshaotian Sep 24, 2025
beaaaec
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 24, 2025
ce11adb
fix bug
littledgg Sep 24, 2025
4c06088
Merge pull request #12 from littledgg/mtp
gongshaotian Sep 24, 2025
c885ba6
Merge branch 'mtp' of https://github.com/gongshaotian/FastDeploy into…
Sep 24, 2025
d23206c
fix merge bug
Sep 24, 2025
4a8f947
fix typo
Sep 25, 2025
2137520
fix bug
Sep 25, 2025
a4323aa
merge develop
Oct 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ void MultiQueryAppendAttention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(sm_count * 4,
num_heads); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
Expand Down Expand Up @@ -1268,7 +1268,7 @@ void MultiQueryAppendAttention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(sm_count * 4,
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
Expand Down
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/append_attn/append_attention_c4_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,7 @@ void MultiQueryAppendC4Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(sm_count * 4,
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
Expand Down Expand Up @@ -1506,7 +1506,7 @@ void MultiQueryAppendC4Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(sm_count * 4,
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
Expand Down
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/append_attn/append_attention_c8_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ void MultiQueryAppendC8Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(sm_count * 4,
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
Expand Down Expand Up @@ -1573,7 +1573,7 @@ void MultiQueryAppendC8Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(sm_count * 4,
num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
Expand Down
5 changes: 5 additions & 0 deletions custom_ops/gpu_ops/append_attn/append_attention_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2418,6 +2418,9 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意下编码规范

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注意下编码规范

这里能把 bid 从 uint32_t 切换成 int 吗?取值范围变小了有无风险?

continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
Expand All @@ -2437,6 +2440,8 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

continue;
}

using LoadT = AlignedVector<T, vec_size>;
Expand Down
214 changes: 107 additions & 107 deletions custom_ops/gpu_ops/append_attn/append_attention_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,113 +232,113 @@ void CascadeAppendAttentionKernel(
enable_prefill,
stream,
out);
} else if (cache_quant_type_str == "cache_int8") {
CascadeAppendAttentionC8Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {
CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
qkv,
cache_k,
cache_v,
attn_mask,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
shift_bias,
smooth_weight,
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
num_blocks,
block_shape_q,
max_seq_len,
max_dec_len,
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
enable_prefill,
stream,
out);
// } else if (cache_quant_type_str == "cache_int8") {
// CascadeAppendAttentionC8Kernel<T, OutT>(meta_data,
// qkv,
// cache_k,
// cache_v,
// attn_mask,
// cache_k_scale,
// cache_v_scale,
// cache_k_zp,
// cache_v_zp,
// shift_bias,
// smooth_weight,
// seq_lens_q,
// seq_lens_kv,
// seq_lens_encoder,
// batch_id_per_token,
// cu_seqlens_q,
// block_table,
// batch_ids,
// tile_ids_per_batch,
// num_blocks,
// block_shape_q,
// max_seq_len,
// max_dec_len,
// quant_max_bound,
// quant_min_bound,
// in_scale,
// max_partition_size,
// encoder_max_partition_size,
// speculate_max_draft_token_num,
// causal,
// is_decoder,
// enable_prefill,
// cache_quant_type_str,
// stream,
// out);
// } else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
// CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
// qkv,
// cache_k,
// cache_v,
// attn_mask,
// cache_k_scale,
// cache_v_scale,
// cache_k_zp,
// cache_v_zp,
// shift_bias,
// smooth_weight,
// seq_lens_q,
// seq_lens_kv,
// seq_lens_encoder,
// batch_id_per_token,
// cu_seqlens_q,
// block_table,
// batch_ids,
// tile_ids_per_batch,
// num_blocks,
// block_shape_q,
// max_seq_len,
// max_dec_len,
// quant_max_bound,
// quant_min_bound,
// in_scale,
// max_partition_size,
// encoder_max_partition_size,
// speculate_max_draft_token_num,
// causal,
// is_decoder,
// enable_prefill,
// cache_quant_type_str,
// stream,
// out);
// } else if (cache_quant_type_str == "cache_int4_zp") {
// CascadeAppendAttentionC4Kernel<T, OutT>(meta_data,
// qkv,
// cache_k,
// cache_v,
// attn_mask,
// cache_k_scale,
// cache_v_scale,
// cache_k_zp,
// cache_v_zp,
// shift_bias,
// smooth_weight,
// seq_lens_q,
// seq_lens_kv,
// seq_lens_encoder,
// batch_id_per_token,
// cu_seqlens_q,
// block_table,
// batch_ids,
// tile_ids_per_batch,
// num_blocks,
// block_shape_q,
// max_seq_len,
// max_dec_len,
// quant_max_bound,
// quant_min_bound,
// in_scale,
// max_partition_size,
// encoder_max_partition_size,
// speculate_max_draft_token_num,
// causal,
// is_decoder,
// enable_prefill,
// stream,
// out);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
Expand Down
Loading