Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions demo/vibevoice_asr_gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, COMMON_AUDIO_EXTS


def format_plain_transcription(segments: List[Dict]) -> str:
    """Format parsed ASR segments as copyable transcript text.

    Args:
        segments: Parsed ASR segments; each dict may carry a 'text' entry.
            Values are coerced via str() so non-string payloads are tolerated.

    Returns:
        One line per non-empty (after stripping) segment text, joined with
        newlines. Segments with missing, empty, or whitespace-only text are
        skipped. Empty input yields "".
    """
    # Generator + filter replaces the manual append loop; join builds the
    # final string in one pass.
    stripped = (str(segment.get('text', '')).strip() for segment in segments)
    return "\n".join(text for text in stripped if text)


class VibeVoiceASRInference:
"""Simple inference wrapper for VibeVoice ASR model."""

Expand Down Expand Up @@ -526,7 +536,7 @@ def transcribe_audio(
do_sample: bool,
repetition_penalty: float = 1.0,
context_info: str = ""
) -> Generator[Tuple[str, str], None, None]:
) -> Generator[Tuple[str, str, str], None, None]:
"""
Transcribe audio and return results with audio segments (streaming version).

Expand All @@ -539,14 +549,14 @@ def transcribe_audio(
context_info: Optional context information (e.g., hotwords, speaker names, topics)

Yields:
Tuple of (raw_text, audio_segments_html)
Tuple of (raw_text, plain_text, audio_segments_html)
"""
if asr_model is None:
yield "❌ Please load a model first!", ""
yield "❌ Please load a model first!", "", ""
return

if not audio_path_input and audio_input is None:
yield "❌ Please provide audio input!", ""
yield "❌ Please provide audio input!", "", ""
return

try:
Expand All @@ -555,7 +565,7 @@ def transcribe_audio(
end_sec = parse_time_to_seconds(end_time_input)
print(f"[INFO] Parsed time range: start={start_sec}, end={end_sec}")
if (start_time_input and start_sec is None) or (end_time_input and end_sec is None):
yield "❌ Invalid time format. Use seconds or hh:mm:ss.", ""
yield "❌ Invalid time format. Use seconds or hh:mm:ss.", "", ""
return

audio_path = None
Expand All @@ -566,10 +576,10 @@ def transcribe_audio(
candidate = Path(audio_path_input.strip())
# Security: validate file extension to prevent arbitrary file probing
if candidate.suffix.lower() not in {e.lower() for e in COMMON_AUDIO_EXTS}:
yield "❌ Unsupported audio format.", ""
yield "❌ Unsupported audio format.", "", ""
return
if not candidate.exists():
yield f"❌ Provided path does not exist: {candidate}", ""
yield f"❌ Provided path does not exist: {candidate}", "", ""
return
audio_path = str(candidate)
print(f"[INFO] Using provided audio path: {audio_path}")
Expand All @@ -582,7 +592,7 @@ def transcribe_audio(
sample_rate, audio_array = audio_input
print(f"[INFO] Received microphone audio with sample_rate={sample_rate}")
elif audio_path is None:
yield "❌ Invalid audio input format!", ""
yield "❌ Invalid audio input format!", "", ""
return

# If slicing is requested, load and slice audio
Expand All @@ -593,11 +603,11 @@ def transcribe_audio(
audio_array, sample_rate = load_audio_use_ffmpeg(audio_path, resample=False)
print("[INFO] Loaded audio for slicing via ffmpeg")
except Exception as exc:
yield f"❌ Failed to load audio for slicing: {exc}", ""
yield f"❌ Failed to load audio for slicing: {exc}", "", ""
return
sliced_path, err = slice_audio_to_temp(audio_array, sample_rate, start_sec, end_sec)
if err:
yield f"❌ {err}", ""
yield f"❌ {err}", "", ""
return
audio_path = sliced_path
print(f"[INFO] Sliced audio written to temp file: {audio_path}")
Expand Down Expand Up @@ -652,13 +662,13 @@ def run_transcription():
# Show streaming output with live stats, format for readability
formatted_text = generated_text.replace('},', '},\n')
streaming_output = f"--- 🔴 LIVE Streaming Output (tokens: {token_count}, time: {elapsed:.1f}s) ---\n{formatted_text}"
yield streaming_output, "<div style='padding: 20px; text-align: center; color: #6c757d;'>⏳ Generating transcription... Audio segments will appear after completion.</div>"
yield streaming_output, "", "<div style='padding: 20px; text-align: center; color: #6c757d;'>⏳ Generating transcription... Audio segments will appear after completion.</div>"

# Wait for thread to complete
transcription_thread.join()

if result_container["error"]:
yield f"❌ Error during transcription: {result_container['error']}", ""
yield f"❌ Error during transcription: {result_container['error']}", "", ""
return

result = result_container["result"]
Expand All @@ -684,6 +694,7 @@ def run_transcription():
print(f"[DEBUG] Raw model output:")
print(f"[DEBUG] {result['raw_text']}")
print(f"[DEBUG] Found {len(result['segments'])} segments")
plain_text = format_plain_transcription(result['segments'])

# Create audio segments with server-side encoding (low quality for minimal transfer)
# Using: 16kHz mono MP3 @ 32kbps = ~4KB per second of audio
Expand Down Expand Up @@ -906,12 +917,12 @@ def run_transcription():
"""

# Final yield with complete results
yield raw_output, audio_segments_html
yield raw_output, plain_text, audio_segments_html

except Exception as e:
print(f"Error during transcription: {e}")
print(traceback.format_exc())
yield f"❌ Error during transcription: {str(e)}", ""
yield f"❌ Error during transcription: {str(e)}", "", ""


def _detect_device_and_attn(
Expand Down Expand Up @@ -1108,6 +1119,14 @@ def create_gradio_interface(
gr.Markdown("## 📝 Results")

with gr.Tabs():
with gr.TabItem("Plain Text"):
plain_text_output = gr.Textbox(
label="Plain Transcription Text",
lines=8,
max_lines=20,
interactive=False
)

with gr.TabItem("Raw Output"):
raw_output = gr.Textbox(
label="Raw Transcription Output",
Expand Down Expand Up @@ -1158,7 +1177,7 @@ def set_stop_flag():
repetition_penalty_slider,
context_info_input
],
outputs=[raw_output, audio_segments_output]
outputs=[raw_output, plain_text_output, audio_segments_output]
)

stop_button.click(
Expand Down