Skip to content

Commit eb86187

Browse files
authored
feat(asr): add MPS/Apple Silicon support to Gradio ASR demo (#344)
* feat(asr): add MPS/Apple Silicon support to Gradio ASR demo The Gradio ASR demo only detected CUDA and fell back to CPU, ignoring Apple Metal Performance Shaders entirely. On Apple Silicon Macs the demo was essentially unusable—model inference ran on the CPU extremely slowly. Changes: - Add _detect_device_and_attn() helper that auto-detects the best available device (CUDA > MPS > CPU) and picks a compatible attention implementation (flash_attention_2 for CUDA when available, sdpa otherwise). This mirrors the pattern already used in the TTS demo (realtime_model_inference_from_file.py). - Update VibeVoiceASRInference.__init__ to handle MPS device loading: load model on CPU first then move to MPS, since device_map='mps' is not supported by Accelerate. - Update initialize_model() to use float32 for MPS (and CPU), matching the file inference script's dtype selection. - Add --device CLI argument (auto|cuda|mps|xpu|cpu) with sensible auto default, and change --attn_implementation default from flash_attention_2 to auto. - Add 14 regression tests covering auto-detection, explicit device selection, MPS-unavailable fallback, attention implementation resolution, dtype selection, and CLI argument parsing. Refs #339 * chore: remove test file per reviewer request Drop tests/test_gradio_asr_device_detection.py to keep PR scope to the Gradio demo changes only. --------- Co-authored-by: voidborne-d <voidborne-d@users.noreply.github.com> Co-authored-by: d 🔹 <258577966+voidborne-d@users.noreply.github.com>
1 parent 4a78d3e commit eb86187

1 file changed

Lines changed: 99 additions & 18 deletions

File tree

demo/vibevoice_asr_gradio_demo.py

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, model_path: str, device: str = "cuda", dtype: torch.dtype = t
6565
6666
Args:
6767
model_path: Path to the pretrained model (HuggingFace format directory or model name)
68-
device: Device to run inference on
68+
device: Device to run inference on (cuda, mps, xpu, cpu, auto)
6969
dtype: Data type for model weights
7070
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
7171
"""
@@ -74,17 +74,34 @@ def __init__(self, model_path: str, device: str = "cuda", dtype: torch.dtype = t
7474
# Load processor
7575
self.processor = VibeVoiceASRProcessor.from_pretrained(model_path)
7676

77-
# Load model
77+
# Load model with device-specific handling
7878
print(f"Using attention implementation: {attn_implementation}")
79-
self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
80-
model_path,
81-
dtype=dtype,
82-
device_map=device if device == "auto" else None,
83-
attn_implementation=attn_implementation,
84-
trust_remote_code=True
85-
)
86-
87-
if device != "auto":
79+
if device == "mps":
80+
# MPS: load onto CPU first, then move (device_map="mps" is not supported)
81+
self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
82+
model_path,
83+
dtype=dtype,
84+
device_map=None,
85+
attn_implementation=attn_implementation,
86+
trust_remote_code=True
87+
)
88+
self.model = self.model.to("mps")
89+
elif device == "auto":
90+
self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
91+
model_path,
92+
dtype=dtype,
93+
device_map="auto",
94+
attn_implementation=attn_implementation,
95+
trust_remote_code=True
96+
)
97+
else:
98+
self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
99+
model_path,
100+
dtype=dtype,
101+
device_map=device if device != "auto" else None,
102+
attn_implementation=attn_implementation,
103+
trust_remote_code=True
104+
)
88105
self.model = self.model.to(device)
89106

90107
self.device = device if device != "auto" else next(self.model.parameters()).device
@@ -480,7 +497,11 @@ def initialize_model(model_path: str, device: str = "cuda", attn_implementation:
480497
"""Initialize the ASR model."""
481498
global asr_model
482499
try:
483-
dtype = torch.bfloat16 if device != "cpu" else torch.float32
500+
# MPS and CPU require float32; CUDA/XPU can use bfloat16
501+
if device in ("mps", "cpu"):
502+
dtype = torch.float32
503+
else:
504+
dtype = torch.bfloat16
484505
asr_model = VibeVoiceASRInference(
485506
model_path=model_path,
486507
device=device,
@@ -893,17 +914,68 @@ def run_transcription():
893914
yield f"❌ Error during transcription: {str(e)}", ""
894915

895916

896-
def create_gradio_interface(model_path: str, default_max_tokens: int = 8192, attn_implementation: str = "flash_attention_2"):
917+
def _detect_device_and_attn(
918+
device: str = "auto",
919+
attn_implementation: str = "auto",
920+
):
921+
"""
922+
Auto-detect the best device and attention implementation.
923+
924+
Args:
925+
device: Explicit device or ``"auto"`` for best available.
926+
attn_implementation: Explicit implementation or ``"auto"``.
927+
928+
Returns:
929+
(device, attn_implementation) tuple.
930+
"""
931+
# --- resolve device ---
932+
if device == "auto":
933+
if torch.cuda.is_available():
934+
device = "cuda"
935+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
936+
device = "mps"
937+
else:
938+
device = "cpu"
939+
elif device == "mps" and not (
940+
hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
941+
):
942+
print("Warning: MPS requested but not available. Falling back to CPU.")
943+
device = "cpu"
944+
945+
# --- resolve attention ---
946+
if attn_implementation == "auto":
947+
if device == "cuda":
948+
try:
949+
import flash_attn # noqa: F401
950+
attn_implementation = "flash_attention_2"
951+
except ImportError:
952+
print("flash_attn not installed, falling back to sdpa")
953+
attn_implementation = "sdpa"
954+
else:
955+
# MPS / XPU / CPU don't support flash_attention_2
956+
attn_implementation = "sdpa"
957+
958+
print(f"Using device: {device}, attn_implementation: {attn_implementation}")
959+
return device, attn_implementation
960+
961+
962+
def create_gradio_interface(
963+
model_path: str,
964+
default_max_tokens: int = 8192,
965+
device: str = "auto",
966+
attn_implementation: str = "auto",
967+
):
897968
"""Create and launch Gradio interface.
898969
899970
Args:
900971
model_path: Path to the model (HuggingFace format directory or model name)
901972
default_max_tokens: Default value for max_new_tokens slider
902-
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
973+
device: Device to run inference on ('auto', 'cuda', 'mps', 'xpu', 'cpu')
974+
attn_implementation: Attention implementation to use ('auto', 'flash_attention_2', 'sdpa', 'eager')
903975
"""
904976

905977
# Initialize model at startup
906-
device = "cuda" if torch.cuda.is_available() else "cpu"
978+
device, attn_implementation = _detect_device_and_attn(device, attn_implementation)
907979
model_status = initialize_model(model_path, device, attn_implementation)
908980
print(model_status)
909981

@@ -1126,11 +1198,19 @@ def main():
11261198
default="microsoft/VibeVoice-ASR",
11271199
help="Path to the model (HuggingFace format directory or model name)"
11281200
)
1201+
parser.add_argument(
1202+
"--device",
1203+
type=str,
1204+
default="auto",
1205+
choices=["auto", "cuda", "mps", "xpu", "cpu"],
1206+
help="Device to run inference on. 'auto' detects the best available (default: auto)"
1207+
)
11291208
parser.add_argument(
11301209
"--attn_implementation",
11311210
type=str,
1132-
default="flash_attention_2",
1133-
help="Attention implementation to use (default: flash_attention_2)"
1211+
default="auto",
1212+
choices=["auto", "flash_attention_2", "sdpa", "eager"],
1213+
help="Attention implementation to use. 'auto' selects the best for your device (default: auto)"
11341214
)
11351215
parser.add_argument(
11361216
"--max_new_tokens",
@@ -1162,7 +1242,8 @@ def main():
11621242
demo, custom_css = create_gradio_interface(
11631243
model_path=args.model_path,
11641244
default_max_tokens=args.max_new_tokens,
1165-
attn_implementation=args.attn_implementation
1245+
device=args.device,
1246+
attn_implementation=args.attn_implementation,
11661247
)
11671248

11681249
print(f"🚀 Starting VibeVoice ASR Demo...")

0 commit comments

Comments
 (0)