Related: #3125.
`trtllm_batch_decode_with_kv_cache_mla` currently does not expose `cum_seq_lens_q` in its API, even though it is needed in several places:

- `trtllm_batch_decode_with_kv_cache_mla` is also used in prefill, where `query` is forced to be reshaped into `(num_tokens, 1, num_heads, head_dim)`; this causes a trtllm-gen attention error at large batch sizes (#3125). A minimal sketch of the problem follows this list.
- Speculative decoding with a variable number of speculative tokens.
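A minimal sketch of the forced reshape, with illustrative shapes (the actual dimensions depend on the model):

```python
# Sketch of the problem: without cum_seq_lens_q, a ragged prefill batch must
# be flattened so that every token becomes its own batch entry with q_len == 1,
# and batch_size balloons to num_tokens, which trips trtllm-gen at scale (#3125).
import torch

num_tokens, num_heads, head_dim = 16, 128, 576  # num_tokens is large in practice
query = torch.randn(num_tokens, num_heads, head_dim)

# (num_tokens, num_heads, head_dim) -> (num_tokens, 1, num_heads, head_dim):
# the fake q_len-1 axis makes the kernel see a batch of num_tokens requests.
query = query.view(num_tokens, 1, num_heads, head_dim)
print(query.shape)  # torch.Size([16, 1, 128, 576])
```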
The underlying trtllm-gen kernel launcher/kernel supports `cum_seq_lens_q` (`flashinfer/csrc/trtllm_fmha_kernel_launcher.cu`, lines 237 to 247 at `a99ee72`):
```cpp
void trtllm_paged_attention_decode(
    TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
    TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
    Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
    double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
    int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
    int64_t workspace_size, Optional<TensorView> attention_sinks,
    Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
    Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor,
    Optional<bool> uses_shared_paged_kv_idx) {
```
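For reference, `cum_seq_lens_q` presumably follows the usual varlen convention (an assumption; the launcher source defines the exact layout): an exclusive prefix sum over per-request query lengths, of length `batch_size + 1`:

```python
# Hedged sketch: build cum_seq_lens_q as a prefix sum over per-request query
# lengths. The [0, 3, 4, 9] layout is the common varlen convention, assumed
# here rather than quoted from the launcher.
import torch

q_lens = torch.tensor([3, 1, 5], dtype=torch.int32)  # e.g. variable speculative tokens
cum_seq_lens_q = torch.zeros(q_lens.numel() + 1, dtype=torch.int32)
cum_seq_lens_q[1:] = torch.cumsum(q_lens, dim=0, dtype=torch.int32)
print(cum_seq_lens_q)  # tensor([0, 3, 4, 9], dtype=torch.int32)
max_q_len = int(q_lens.max())  # 5
```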
The Python API, however, does not expose it (`flashinfer/flashinfer/mla/_core.py`, lines 592 to 612 at `9e3d8b9`):
```python
def trtllm_batch_decode_with_kv_cache_mla(
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    workspace_buffer: torch.Tensor,
    qk_nope_head_dim: int,  # TODO: remove in 1.0?
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    max_seq_len: int,
    sparse_mla_top_k: int = 0,
    out: Optional[torch.Tensor] = None,
    bmm1_scale: Union[float, torch.Tensor] = 1.0,
    bmm2_scale: Union[float, torch.Tensor] = 1.0,
    sinks: Optional[List[torch.Tensor]] = None,
    skip_softmax_threshold_scale_factor: Optional[float] = None,
    enable_pdl: bool | None = None,
    backend: str = "auto",
    is_var_seq: bool = True,
    uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
```
and the wrapper hardcodes `None` for `cum_seq_lens_q` when calling the kernel (`flashinfer/flashinfer/mla/_core.py`, lines 770 to 798 at `9e3d8b9`):

```python
run_func(
    out,
    None,  # fp4 output not supported in wrapper api yet.
    query,
    kv_cache,
    kv_cache,
    workspace_buffer,
    block_tables,
    seq_lens,
    max_q_len,
    max_seq_len,
    bmm1_scale,
    bmm2_scale,
    -1,  # o_sf_scale
    -1,  # o_sf_vec_size
    0,  # o_sf_start_index
    batch_size,
    -1,  # window_left
    sparse_mla_top_k,
    sm_count,
    enable_pdl,
    workspace_buffer.numel() * workspace_buffer.element_size(),
    sinks,
    None,  # cum_seq_lens_q
    None,  # key_block_scales
    None,  # value_block_scales
    skip_softmax_threshold_scale_factor,
    uses_shared_paged_kv_idx,
)
```
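One possible shape for the extension, sketched here as a proposal (parameter names and defaults are my assumptions, not an existing FlashInfer API):

```python
# Hypothetical signature extension (names/defaults are assumptions): accept an
# optional cum_seq_lens_q plus max_q_len and forward them to the kernel call
# in place of the hardcoded None / implicit q_len of 1.
from typing import Optional
import torch

def trtllm_batch_decode_with_kv_cache_mla(
    query: torch.Tensor,  # (num_tokens, num_heads, head_dim) when ragged
    # ... existing parameters unchanged ...
    cum_seq_lens_q: Optional[torch.Tensor] = None,  # (batch_size + 1,), int32
    max_q_len: int = 1,  # per-request maximum query length
) -> torch.Tensor:
    # Inside the wrapper: pass max_q_len where 1 is implied today, and pass
    # cum_seq_lens_q where the call currently hardcodes None.
    ...
```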
There is also an argument `is_var_seq`, which is only used in the cute-dsl backend.