Skip to content

Commit 8a0efa8

Browse files
committed
[Feature] V4 KVCompressor: compressed-kv RoPE + per-ratio cu_seq_lens_out reuse
Two changes that co-evolve the KVCompressor.forward signature (compressed-rope table + precomputed boundaries are added side by side), so they land together.

1. Compressed-kv RoPE. After the chunk softmax + norm, rotate each compressed chunk's rope tail at its window-center position, mirroring HF DeepseekV4{CSA,HCA}Compressor.forward. ``qk_rope_head_dim`` is wired from the DSA/Indexer configs into the internal KVCompressor, and DSA now forwards ``position_embeddings_compressed`` to the compressor (required for compress_ratio > 0, not just == 4). The chunk->sample map uses ``searchsorted(cu_seq_lens_out, ., right=True) - 1`` — right=True is load-bearing: a chunk on a sample boundary is the first chunk of the next sample, and mapping it to the previous one overruns ``first_token_per_chunk`` and indexes the rope table out of bounds.

2. Hoist cu_seq_lens_out. ``KVCompressor.build_cu_seq_lens_out`` computes the per-sample compressed boundaries once; DeepSeekV4 forward builds one per distinct compress_ratio and caches it on ``SequenceContext.compressed_cu_seq_lens``, so every decoder layer of that ratio reuses a single cumsum + H2D instead of recomputing it. ``total_c`` stays derived in the compressor from the CPU mirror (it must remain a Python int and would force a recompile if threaded through the compiled attn graph). Standalone callers (no cache on seq_ctx) fall back to building it in-place.
1 parent 97e922b commit 8a0efa8

5 files changed

Lines changed: 181 additions & 27 deletions

File tree

xtuner/v1/data_proto/sequence_context.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ class SequenceContext:
3535
# consumers must tolerate fallback to a GPU ``.item()`` path in that case.
3636
cu_seq_lens_q_cpu: torch.Tensor | None
3737
cu_seq_lens_k_cpu: torch.Tensor | None
38+
# Optional per-``compress_ratio`` cache of compressed-chunk cumulative boundaries
39+
# (``{ratio: cu_seq_lens_out}``), populated by chunk-compression models (DeepSeek-V4)
40+
# at the start of forward via ``KVCompressor.build_cu_seq_lens_out`` so every decoder
41+
# layer of a given ratio reuses one cumsum + H2D instead of recomputing it. ``None`` for
42+
# models that don't compress; not a constructor argument (set post-construction, like
43+
# ``seq_idx``) so it stays out of the generic SequenceContext contract.
44+
compressed_cu_seq_lens: dict[int, torch.Tensor] | None
3845
max_length_q: torch.Tensor
3946
max_length_k: torch.Tensor
4047
num_padding: int
@@ -130,6 +137,9 @@ def __init__(
130137
self._shard_start = shard_start
131138
self._shard_size = shard_size
132139
self.seq_idx = None
140+
# Populated post-construction by the model forward (chunk-compression models only); see the
141+
# field declaration above.
142+
self.compressed_cu_seq_lens = None
133143

134144
# `DeviceMesh.get_local_rank` is not compatible with `torch.compile`, we calculate `_sp_rank` in
135145
# `SequenceContext`

xtuner/v1/model/moe/deepseek_v4.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from transformers import AutoConfig
3232
from xtuner.v1.module import HashRouterConfig, NoAuxRouterConfig
3333
from xtuner.v1.module.attention.dsa import DSAConfig
34+
from xtuner.v1.module.attention.kv_compressor import KVCompressor
3435
from xtuner.v1.module.decoder_layer.deepseek_v4_decoder_layer import V4DecoderLayer
3536
from xtuner.v1.module.decoder_layer.hc_block import HCWrapperConfig
3637
from xtuner.v1.module.decoder_layer.moe_decoder_layer import (
@@ -506,6 +507,10 @@ def build_layers(self, config: MoEConfig) -> nn.ModuleDict:
506507
f"compress_ratios (len={len(compress_ratios) if compress_ratios else 0}) must cover "
507508
f"all {v4_cfg.num_hidden_layers} hidden layers"
508509
)
510+
# Distinct positive compress_ratios across the stack. The model forward builds one
511+
# ``cu_seq_lens_out`` per ratio and caches it on the SequenceContext, so every layer of
512+
# that ratio reuses the cumsum + H2D instead of recomputing it inside its KVCompressor.
513+
self._compressor_ratios = sorted({r for r in compress_ratios[: v4_cfg.num_hidden_layers] if r > 0})
509514

510515
layers = nn.ModuleDict()
511516
for layer_idx in range(v4_cfg.num_hidden_layers):
@@ -604,10 +609,23 @@ def _should_compute_aux_loss(self, layer_idx: int) -> bool:
604609
# automatically — V4's mtp_block is None (build_mtp_block returns None) — so the
605610
# parent's MTP branch is a no-op (PR9 follow-up wires the V4-specific MTP head).
606611

612+
def _assign_compressed_cu_seq_lens(self, seq_ctx) -> None:
613+
# Build ``cu_seq_lens_out`` once per distinct compress_ratio and cache it on the
614+
# SequenceContext, so the per-layer KVCompressor (DSA + Indexer) reuses it instead of
615+
# re-running the cumsum + H2D every call. Keyed by ratio because the chunk count is
616+
# ``ceil(L_i / ratio)`` — different for the ratio-4 and ratio-128 layers.
617+
if not self._compressor_ratios:
618+
return
619+
seq_ctx.compressed_cu_seq_lens = {
620+
ratio: KVCompressor.build_cu_seq_lens_out(seq_ctx.cu_seq_lens_q, seq_ctx.cu_seq_lens_q_cpu, ratio)[0]
621+
for ratio in self._compressor_ratios
622+
}
623+
607624
@override
608625
def _prepare_hidden_states(self, seq_ctx) -> tuple[torch.Tensor, dict]: # type: ignore[override]
609626
assert seq_ctx.position_ids is not None
610627
assert seq_ctx.input_ids is not None, "DeepSeekV4 requires input_ids (HashRouter consumes them)"
628+
self._assign_compressed_cu_seq_lens(seq_ctx)
611629
hidden_states = self.embed_tokens(seq_ctx.input_ids)
612630
# Dense rope (sliding-window heads) and compressed rope (Indexer) both come
613631
# from the same DualRotaryEmbedding; precompute both so each layer picks the
@@ -676,6 +694,7 @@ def _prepare_hidden_states_mb(self, seq_ctx_list) -> tuple[list[torch.Tensor], d
676694
for seq_ctx in seq_ctx_list:
677695
assert seq_ctx.position_ids is not None
678696
assert seq_ctx.input_ids is not None, "DeepSeekV4 requires input_ids (HashRouter consumes them)"
697+
self._assign_compressed_cu_seq_lens(seq_ctx)
679698
h = self.embed_tokens(seq_ctx.input_ids)
680699
pos_emb = self.rotary_emb(h, seq_ctx.position_ids, use_compressed=False)
681700
pos_emb_compressed = _build_compressed_position_embeddings(self.rotary_emb, h, seq_ctx.position_ids)

xtuner/v1/module/attention/dsa.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def __init__(
277277
compress_ratio=compress_ratio,
278278
overlap=(compress_ratio == 4),
279279
rotate=False,
280+
qk_rope_head_dim=dsa_cfg.qk_rope_head_dim,
280281
rms_norm_eps=dsa_cfg.rms_norm_eps,
281282
)
282283
else:
@@ -461,8 +462,13 @@ def forward(
461462
)
462463
if hidden_states.size(-1) != self.hidden_size:
463464
raise ValueError(f"hidden_states last dim {hidden_states.size(-1)} != hidden_size {self.hidden_size}")
464-
if self.compress_ratio == 4 and position_embeddings_compressed is None:
465-
raise ValueError("position_embeddings_compressed is required for compress_ratio == 4 (Indexer rope)")
465+
if self.compress_ratio > 0 and position_embeddings_compressed is None:
466+
raise ValueError(
467+
"position_embeddings_compressed is required for compress_ratio > 0 "
468+
"(the KVCompressor rotates compressed-kv with the compressed-rope basis to "
469+
"match V4 reference Compressor.forward, mirroring HF "
470+
"DeepseekV4{CSA,HCA}Compressor)"
471+
)
466472

467473
cos, sin = position_embeddings
468474
total_tokens = hidden_states.size(1)
@@ -534,10 +540,22 @@ def forward(
534540
# so, the single outer call eliminates N entry/exit pairs and
535541
# batches the GEMMs at the layer boundary.
536542
assert self.compressor is not None # compress_ratio > 0 always materialises it
543+
# Compressed boundaries are built once per compress_ratio at model forward and cached
544+
# on seq_ctx (DeepSeekV4._assign_compressed_cu_seq_lens); reuse this layer's entry instead of
545+
# recomputing the cumsum + H2D. ``None`` (e.g. standalone DSA in a unit test) falls back
546+
# to the compressor building it. The Indexer's internal compressor shares this ratio's
547+
# value. ``.get`` over the constant ``self.compress_ratio`` key stays compile-traceable.
548+
cu_seq_lens_out = (
549+
seq_ctx.compressed_cu_seq_lens.get(self.compress_ratio)
550+
if seq_ctx.compressed_cu_seq_lens is not None
551+
else None
552+
)
537553
kv_compressed, cu_c = self.compressor(
538554
hidden_states,
539555
cu_q,
540556
cu_seq_lens_cpu=seq_ctx.cu_seq_lens_q_cpu,
557+
position_embeddings_compressed=position_embeddings_compressed,
558+
cu_seq_lens_out=cu_seq_lens_out,
541559
) # [1, total_c, D], [B+1]
542560
if self.compress_ratio == 4:
543561
# ``DualRotaryEmbedding`` already emits half-dim cos/sin in the
@@ -564,6 +582,7 @@ def forward(
564582
(cos_c, sin_c),
565583
cu_q,
566584
cu_seq_lens_cpu=seq_ctx.cu_seq_lens_q_cpu,
585+
cu_seq_lens_out=cu_seq_lens_out,
567586
)
568587
else:
569588
# compress_ratio == 128: deterministic positional top-k.

xtuner/v1/module/attention/indexer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ def __init__(self, config: IndexerConfig) -> None:
181181
compress_ratio=config.compress_ratio,
182182
overlap=True,
183183
rotate=True,
184+
# Wire the same ``rope_head_dim`` the Indexer uses for its q rope
185+
# tail through to the internal KVCompressor: HF
186+
# ``DeepseekV4Indexer.forward`` rotates its ``compressed`` output at
187+
# window-center positions before scoring (modeling_deepseek_v4.py L541),
188+
# so the internal compressor must do the same.
189+
qk_rope_head_dim=config.rope_head_dim,
184190
rms_norm_eps=config.rms_norm_eps,
185191
)
186192

@@ -191,6 +197,7 @@ def forward(
191197
position_embeddings_compressed: tuple[torch.Tensor, torch.Tensor],
192198
cu_seq_lens: torch.Tensor,
193199
cu_seq_lens_cpu: torch.Tensor | None = None,
200+
cu_seq_lens_out: torch.Tensor | None = None,
194201
) -> torch.Tensor:
195202
"""Compute per-query top-k compressed-KV indices.
196203
@@ -206,6 +213,12 @@ def forward(
206213
DSA layer's dual-rope module.
207214
cu_seq_lens (torch.Tensor): 1D int32 cumulative per-sample token
208215
counts with length ``num_samples + 1``.
216+
cu_seq_lens_cpu (torch.Tensor | None): Optional CPU mirror of
217+
``cu_seq_lens`` forwarded to the internal compressor.
218+
cu_seq_lens_out (torch.Tensor | None): Optional precomputed compressed
219+
boundaries for this ``compress_ratio`` (built once at model
220+
forward); forwarded to the internal compressor so it skips the
221+
per-call cumsum + H2D. See :meth:`KVCompressor.build_cu_seq_lens_out`.
209222
210223
Returns:
211224
torch.Tensor: Top-k indices shaped ``[1, total_tokens, index_topk]``
@@ -259,6 +272,8 @@ def forward(
259272
hidden_states,
260273
cu_seq_lens,
261274
cu_seq_lens_cpu=cu_seq_lens_cpu,
275+
position_embeddings_compressed=position_embeddings_compressed,
276+
cu_seq_lens_out=cu_seq_lens_out,
262277
)
263278

264279
# Step 5: gate weights, scaled exactly as V4 reference L418.

0 commit comments

Comments
 (0)