Support cum_seq_lens_q in trtllm_batch_decode_with_kv_cache_mla trtllm-gen backend #3131

Description

@nvjullin

Related #3125.

trtllm_batch_decode_with_kv_cache_mla currently does not expose cum_seq_lens_q in its API, even though it is needed in several places (a sketch of the expected cum_seq_lens_q layout follows this list):

  1. trtllm_batch_decode_with_kv_cache_mla is also used for prefill. Without cum_seq_lens_q, the query has to be reshaped into (num_tokens, 1, num_heads, head_dim), which triggers the trtllm-gen attention error on large batch sizes reported in #3125.
  2. Speculative decoding with a variable number of speculative tokens per request.
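
For reference, here is a minimal sketch (not FlashInfer code) of what a caller would build, assuming the usual exclusive-prefix-sum convention for cum_seq_lens_q over per-request query lengths:

import torch

# Hypothetical per-request query lengths, e.g. a chunked prefill or a
# speculative-decoding batch with a variable number of draft tokens per request.
q_lens = torch.tensor([3, 1, 5], dtype=torch.int32)

# Exclusive prefix sum -> [0, 3, 4, 9]; entries i and i+1 delimit request i's
# tokens inside a packed query of shape (total_q_tokens, num_heads, head_dim).
cum_seq_lens_q = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(q_lens, dim=0, dtype=torch.int32)]
)

max_q_len = int(q_lens.max())             # 5
total_q_tokens = int(cum_seq_lens_q[-1])  # 9, while batch_size = q_lens.numel() = 3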

The underlying trtllm-gen kernel launcher/kernel already supports cum_seq_lens_q:

void trtllm_paged_attention_decode(
    TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
    TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
    Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
    double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
    int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
    int64_t workspace_size, Optional<TensorView> attention_sinks,
    Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
    Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor,
    Optional<bool> uses_shared_paged_kv_idx) {

but the Python API does not expose it: the signature has no cum_seq_lens_q parameter, and the wrapper hard-codes it to None when calling the launcher.
def trtllm_batch_decode_with_kv_cache_mla(
    query: torch.Tensor,
    kv_cache: torch.Tensor,
    workspace_buffer: torch.Tensor,
    qk_nope_head_dim: int,  # TODO: remove in 1.0?
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    max_seq_len: int,
    sparse_mla_top_k: int = 0,
    out: Optional[torch.Tensor] = None,
    bmm1_scale: Union[float, torch.Tensor] = 1.0,
    bmm2_scale: Union[float, torch.Tensor] = 1.0,
    sinks: Optional[List[torch.Tensor]] = None,
    skip_softmax_threshold_scale_factor: Optional[float] = None,
    enable_pdl: bool | None = None,
    backend: str = "auto",
    is_var_seq: bool = True,
    uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:

run_func(
    out,
    None,  # fp4 output not supported in wrapper api yet.
    query,
    kv_cache,
    kv_cache,
    workspace_buffer,
    block_tables,
    seq_lens,
    max_q_len,
    max_seq_len,
    bmm1_scale,
    bmm2_scale,
    -1,  # o_sf_scale
    -1,  # o_sf_vec_size
    0,  # o_sf_start_index
    batch_size,
    -1,  # window_left
    sparse_mla_top_k,
    sm_count,
    enable_pdl,
    workspace_buffer.numel() * workspace_buffer.element_size(),
    sinks,
    None,  # cum_seq_lens_q
    None,  # key_block_scales
    None,  # value_block_scales
    skip_softmax_threshold_scale_factor,
    uses_shared_paged_kv_idx,
)
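
A rough sketch of the requested plumbing (the helper and argument names below are assumptions for illustration, not the actual FlashInfer patch): accept an optional cum_seq_lens_q together with an explicit max_q_len on the wrapper and forward them to run_func in place of the hard-coded values, so the packed query never needs the (num_tokens, 1, num_heads, head_dim) reshape.

from typing import Optional, Tuple

import torch

def _resolve_varlen_query_args(
    query: torch.Tensor,                            # packed (total_q_tokens, num_heads, head_dim)
    cum_seq_lens_q: Optional[torch.Tensor] = None,  # proposed new wrapper argument
    max_q_len: Optional[int] = None,                # must accompany cum_seq_lens_q
) -> Tuple[int, int, Optional[torch.Tensor]]:
    """Hypothetical helper: the values that would fill the max_q_len, batch_size
    and cum_seq_lens_q slots of the run_func call above."""
    if cum_seq_lens_q is None:
        # Behavior matching today's forced reshape: one query token per (pseudo-)request.
        return 1, query.shape[0], None
    assert max_q_len is not None, "max_q_len must be provided with cum_seq_lens_q"
    batch_size = cum_seq_lens_q.numel() - 1  # query stays packed, no reshape needed
    return max_q_len, batch_size, cum_seq_lens_q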

There is an is_var_seq argument, but it is only used by the cute-dsl backend.
