feat(asr): add MPS/Apple Silicon support to Gradio ASR demo (#344)
* feat(asr): add MPS/Apple Silicon support to Gradio ASR demo
The Gradio ASR demo only detected CUDA and fell back to CPU, ignoring
Apple Metal Performance Shaders entirely. On Apple Silicon Macs the demo
was essentially unusable because model inference ran on the CPU and was
extremely slow.
Changes:
- Add _detect_device_and_attn() helper that auto-detects the best
available device (CUDA > MPS > CPU) and picks a compatible attention
implementation (flash_attention_2 for CUDA when available, sdpa
otherwise). This mirrors the pattern already used in the TTS demo
(realtime_model_inference_from_file.py).
- Update VibeVoiceASRInference.__init__ to handle MPS device loading:
load model on CPU first then move to MPS, since device_map='mps' is
not supported by Accelerate.
- Update initialize_model() to use float32 for MPS (and CPU), matching
the file inference script's dtype selection.
- Add --device CLI argument (auto|cuda|mps|xpu|cpu) with sensible auto
default, and change --attn_implementation default from flash_attention_2
to auto.
- Add 14 regression tests covering auto-detection, explicit device
selection, MPS-unavailable fallback, attention implementation
resolution, dtype selection, and CLI argument parsing.
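
For readers without the diff at hand, the selection logic in the list above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code merged in this commit. The function names (probe_backends, resolve_device, resolve_attn, resolve_dtype_name, build_parser) are hypothetical. The bfloat16 choice for CUDA and the fallback to CPU when MPS is requested but unavailable are assumptions too, since the message only says that MPS and CPU use float32.

```python
import argparse
import importlib.util


def probe_backends():
    """Return (cuda_ok, mps_ok, flash_ok); all False if torch is not installed."""
    try:
        import torch
    except ImportError:
        return False, False, False
    cuda_ok = torch.cuda.is_available()
    # Older torch builds may lack the mps backend module entirely.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = bool(mps_backend is not None and mps_backend.is_available())
    # flash_attention_2 needs the separate flash_attn package.
    flash_ok = importlib.util.find_spec("flash_attn") is not None
    return cuda_ok, mps_ok, flash_ok


def resolve_device(requested, cuda_ok, mps_ok):
    """Pick a device: 'auto' prefers CUDA, then MPS, then CPU."""
    if requested == "auto":
        if cuda_ok:
            return "cuda"
        if mps_ok:
            return "mps"
        return "cpu"
    if requested == "mps" and not mps_ok:
        # Assumption: an unavailable MPS request falls back to CPU.
        return "cpu"
    return requested


def resolve_attn(requested, device, flash_ok):
    """Resolve 'auto' to flash_attention_2 on CUDA when available, else sdpa."""
    if requested != "auto":
        return requested
    if device == "cuda" and flash_ok:
        return "flash_attention_2"
    return "sdpa"


def resolve_dtype_name(device):
    """float32 for MPS and CPU (per the commit); bfloat16 on CUDA is assumed."""
    return "bfloat16" if device == "cuda" else "float32"


def build_parser():
    """CLI flags named in the commit message, with 'auto' defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cuda", "mps", "xpu", "cpu"],
    )
    parser.add_argument("--attn_implementation", default="auto")
    return parser
```

On an Apple Silicon machine, resolve_device("auto", *probe_backends()[:2]) would be expected to yield "mps". The model would then be loaded on CPU and moved with .to("mps"), since device_map='mps' is not used.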
Refs #339
* chore: remove test file per reviewer request
Drop tests/test_gradio_asr_device_detection.py to keep PR scope to
the Gradio demo changes only.
---------
Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com>
Co-authored-by: d 🔹 <258577966+voidborne-d@users.noreply.github.com>