Commit 88dc743

Voxtral Realtime: unify SDPA classes and dtype-aware attention masks (#17997)
Unify CudaSDPA and StandardEncoderSDPA into a single StandardSDPA class with a transpose_kv parameter, mirroring the MetalSDPA unification. This gives a symmetric design: MetalSDPA and StandardSDPA share the same interface (n_heads, n_kv_heads, head_dim, transpose_kv).

Make _build_attn_mask and create_causal_mask dtype-aware: masks are now created in the model dtype instead of always float32. This is required because the Metal SDPA kernel reads the mask buffer as device T* (the same element type as Q/K/V), so a float32 mask paired with bf16 Q/K/V would be misinterpreted.
Parent: 8c32dc8
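A minimal sketch of what the unified class might look like, assuming a PyTorch module with the (n_heads, n_kv_heads, head_dim, transpose_kv) signature named in the commit message; the internals below are illustrative, not the actual ExecuTorch implementation:

```python
import torch
import torch.nn.functional as F


class StandardSDPA(torch.nn.Module):
    """Illustrative sketch: one SDPA module covering both the decoder
    (GQA, transpose_kv=False) and the streaming encoder (transpose_kv=True)."""

    def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int,
                 transpose_kv: bool = False):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.transpose_kv = transpose_kv  # True when K/V arrive as [B, S, H, D]

    def forward(self, q, k, v, mask):
        # q arrives as [B, T, H, D]; SDPA wants [B, H, T, D]
        q = q.transpose(1, 2)
        if self.transpose_kv:
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
        # GQA expansion: repeat KV heads to match the query head count
        if self.n_heads != self.n_kv_heads:
            rep = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(rep, dim=1)
            v = v.repeat_interleave(rep, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(1, 2)  # back to [B, T, H, D]
```

With a boolean mask (True = attend), this matches the CUDA path described in the model docs; the additive-mask variant differs only in the mask tensor passed in.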

3 files changed, +116 −123 lines

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -18,13 +18,13 @@
 
 Backend support:
 - XNNPACK (default): Uses custom SDPA op (torch.ops.llama.custom_sdpa) for optimal performance
-- Metal/AOTI: Uses MetalSDPA (_scaled_dot_product_attention_math_for_mps) for text_decoder
-  and StandardEncoderSDPA (F.scaled_dot_product_attention) for streaming encoder,
-  avoiding custom_sdpa which is incompatible with AOTI. Uses Dim.AUTO for audio
-  encoder dynamic shapes (explicit bounds cause issues with AOTI).
-- CUDA/AOTI: Uses CudaSDPA (F.scaled_dot_product_attention with GQA expansion) for text_decoder
-  and StandardEncoderSDPA for streaming encoder. Compiles to CUDA kernels via
-  AOTInductor. Supports int4 quantization via _weight_int4pack_mm fallback kernel.
+- Metal/AOTI: Uses MetalSDPA (_scaled_dot_product_attention_math_for_mps) for both text_decoder
+  and streaming encoder (transpose_kv=True), avoiding custom_sdpa which is
+  incompatible with AOTI. Uses Dim.AUTO for audio encoder dynamic shapes
+  (explicit bounds cause issues with AOTI).
+- CUDA/AOTI: Uses StandardSDPA (F.scaled_dot_product_attention with GQA expansion) for
+  text_decoder and streaming encoder (transpose_kv=True). Compiles to CUDA kernels
+  via AOTInductor. Supports int4 quantization via _weight_int4pack_mm fallback kernel.
 - Portable: Uses custom SDPA like XNNPACK
 
 Usage:
```
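The dtype-aware mask change from the commit message can be illustrated with a short sketch (the name create_causal_mask comes from the commit message; the signature and body here are assumptions, not the actual helper):

```python
import torch


def create_causal_mask(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    """Additive causal mask built in the model dtype, so the Metal kernel,
    which reads the mask buffer as `device T*` (the Q/K/V element type),
    never sees a float32 buffer alongside bf16 Q/K/V."""
    # 0.0 where attention is allowed, -inf strictly above the diagonal
    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype)
    return torch.triu(mask, diagonal=1)
```

The point is only that `dtype` is threaded through instead of hard-coding `torch.float32`; the mask contents are the usual additive causal pattern.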

examples/models/voxtral_realtime/model.md

Lines changed: 15 additions & 9 deletions

````diff
@@ -102,7 +102,7 @@ VoxtralRealtimeModel
 attention: LMAttention
 wq/wk/wv/wo: Linear (no bias)
 kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal/CUDA)
-sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or CudaSDPA (CUDA)
+sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA)
 ffn_norm: RMSNorm
 ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
 feed_forward: LMMLP (w1/w2/w3)
@@ -116,7 +116,7 @@ StreamingAudioEncoderExport
 enc_norm: RMSNorm (shared from encoder.norm)
 adapter: AudioLanguageAdapter (shared from model.adapter)
 kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal/CUDA)
-sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal/CUDA)
+sdpa: SDPA (XNNPACK) or MetalSDPA (Metal, transpose_kv=True) or StandardSDPA (CUDA, transpose_kv=True)
 inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
 ```
 
@@ -164,10 +164,12 @@ Handles GQA expansion internally and upcasts to float32.
 
 **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
 which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth
-overhead of `repeat_interleave`. Uses explicit additive attention masks.
-AOTInductor has compatibility issues with the `custom_sdpa` custom op.
+overhead of `repeat_interleave`. Uses explicit additive attention masks
+that must match the Q/K/V dtype (the kernel reads masks as `device T*`).
+Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder
+(no GQA, `transpose_kv=True`).
 
-**CUDA:** `CudaSDPA` uses `F.scaled_dot_product_attention` with
+**CUDA:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
 `repeat_interleave` for GQA expansion (32 query heads / 8 KV heads = 4x).
 Uses boolean attention masks (`True`=attend, `False`=masked) as required
 by the Triton SDPA kernel. The CUDA backend's Triton SDPA replacement
@@ -183,7 +185,7 @@ pass optimizes the attention kernel at compile time.
 require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache.
 
 **Metal/CUDA:** Q/K/V projections still produce `[B, T, H, D]`, but
-`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA`/`CudaSDPA` transpose q to
+`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA`/`StandardSDPA` transpose q to
 `[B, H, T, D]` for the SDPA kernel, then transpose back.
 
 ### Adaptive RMSNorm
@@ -225,9 +227,13 @@ mel_chunk (1, 128, 8) + enc_input_pos (4,)
 **XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
 custom op) and `SDPA` (`custom_sdpa`).
 
-**Metal/CUDA:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
-buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with
-explicit sliding window masks) — AOTI-compatible patterns avoiding custom ops.
+**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
+buffer) and `MetalSDPA` (native MPS SDPA kernel with `transpose_kv=True`).
+Masks are created in the model dtype to match the kernel's `device T*` expectation.
+
+**CUDA:** Uses `StandardEncoderRingKVCache` and `StandardSDPA`
+(`F.scaled_dot_product_attention` with `transpose_kv=True` and explicit
+sliding window masks).
 
 ### Streaming decode loop
 
````
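The explicit sliding-window masks mentioned for the streaming encoder can be sketched as follows; the helper name, signature, and the simplified absolute-position layout are illustrative (the real ring-buffer indexing is more involved), and the dtype parameter mirrors the dtype-aware mask creation this commit introduces:

```python
import torch


def sliding_window_mask(q_pos, cache_len, window, dtype):
    """Additive mask: each query position may attend only to key positions
    within the last `window` steps, inclusive of itself (hypothetical sketch)."""
    kv_pos = torch.arange(cache_len)
    # allowed iff the kv position lies within [q - window + 1, q]
    dist = q_pos[:, None] - kv_pos[None, :]
    allowed = (dist >= 0) & (dist < window)
    mask = torch.zeros(len(q_pos), cache_len, dtype=dtype)
    return mask.masked_fill(~allowed, float("-inf"))
```

Building the zeros in `dtype` (rather than float32) keeps the mask buffer's element type aligned with Q/K/V, which is what the Metal kernel's `device T*` read requires.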
