You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Voxtral Realtime: unify SDPA classes and dtype-aware attention masks (#17997)
Unify CudaSDPA and StandardEncoderSDPA into a single StandardSDPA class
with transpose_kv parameter, mirroring the MetalSDPA unification. This
gives a symmetric design: MetalSDPA and StandardSDPA share the same
interface (n_heads, n_kv_heads, head_dim, transpose_kv).
Make _build_attn_mask and create_causal_mask dtype-aware — masks are now
created in the model dtype instead of always float32. This is required
because the Metal SDPA kernel reads the mask buffer as device T* (same
type as Q/K/V). A float32 mask with bf16 Q/K/V would be misinterpreted.
0 commit comments