Voxtral Realtime: enable bf16 for Metal backend with quantization#17845
Voxtral Realtime: enable bf16 for Metal backend with quantization#17845mergennachin merged 2 commits intomainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17845
Note: Links to docs will display an error until the docs builds have been completed.
|
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
Enables and recommends bf16 for Voxtral Realtime exports on Metal when using quantization, updating CI export arguments and user-facing docs to reflect the preferred configuration for memory/throughput.
Changes:
- Update Voxtral Realtime docs to include bf16 memory footprint numbers and recommend
--dtype bf16for Metal quantized exports. - Adjust example Metal export command(s) to include
--dtype bf16alongsidefpa4w. - Update Metal CI export script to pass
--dtype bf16for thequantized-int4-metalconfiguration.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| examples/models/voxtral_realtime/model.md | Updates memory calculations and guidance around bf16 + quantization for Metal/CUDA. |
| examples/models/voxtral_realtime/export_voxtral_rt.py | Updates usage example to show Metal export with bf16 + fpa4w. |
| examples/models/voxtral_realtime/README.md | Updates Metal backend table and export examples to recommend bf16 with fpa4w. |
| .ci/scripts/export_model_artifact.sh | Ensures Metal int4 quantized CI export passes --dtype bf16. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
40b6144 to
52027ff
Compare
The Metal AOTI backend already handles bf16 correctly (fp32 attention masks, fp32 RoPE upcast, dtype-agnostic KV caches and SDPA). Enable --dtype bf16 as the default recipe for Metal CI and update all documentation to recommend bf16 with fpa4w quantization. Fix a Metal shader compilation bug in the streaming encoder where bool.to(bf16) generates `bfloat tmp = 0.0;` — Metal Shading Language doesn't support implicit float-to-bfloat literal conversion. Use .float() instead and let mul_ handle type promotion.
52027ff to
77b74fd
Compare
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
| fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming): | ||
| 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB, | ||
| bf16: ≈ 393 MB. | ||
|
|
| **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps` | ||
| which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth | ||
| overhead of `repeat_interleave`. Uses explicit additive attention masks | ||
| which handles GQA natively (the kernel infers the group ratio from differing | ||
| Q vs K/V head counts), avoiding the memory bandwidth overhead of | ||
| `repeat_interleave`. Uses explicit additive attention masks | ||
| that must match the Q/K/V dtype (the kernel reads masks as `device T*`). |
The Metal AOTI backend already handles bf16 correctly (fp32 attention
masks, fp32 RoPE upcast, dtype-agnostic KV caches and SDPA). Enable
--dtype bf16 as the default recipe for Metal CI and update all
documentation to recommend bf16 with fpa4w quantization.