Skip to content

Commit ae59b3b

Browse files
authored
Merge pull request pipecat-ai#3404 from poseneror/feature/gladia-vad-events
feat(gladia): add VAD events support
2 parents 8b0f0b5 + 3304b18 commit ae59b3b

2 files changed

Lines changed: 44 additions & 0 deletions

File tree

src/pipecat/services/gladia/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ class GladiaInputParams(BaseModel):
169169
pre_processing: Audio pre-processing options
170170
realtime_processing: Real-time processing features
171171
messages_config: WebSocket message filtering options
172+
enable_vad: Enable VAD to trigger end of utterance detection. This should be used
173+
without any other VAD enabled in the agent. When enabled, the service emits
174+
UserStartedSpeakingFrame and UserStoppedSpeakingFrame itself. Defaults to False.
172175
"""
173176

174177
encoding: Optional[str] = "wav/pcm"
@@ -182,3 +185,4 @@ class GladiaInputParams(BaseModel):
182185
pre_processing: Optional[PreProcessingConfig] = None
183186
realtime_processing: Optional[RealtimeProcessingConfig] = None
184187
messages_config: Optional[MessagesConfig] = None
188+
enable_vad: bool = False

src/pipecat/services/gladia/stt.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
StartFrame,
2929
TranscriptionFrame,
3030
TranslationFrame,
31+
UserStartedSpeakingFrame,
32+
UserStoppedSpeakingFrame,
3133
)
3234
from pipecat.services.gladia.config import GladiaInputParams
3335
from pipecat.services.stt_service import WebsocketSTTService
@@ -202,6 +204,7 @@ def __init__(
202204
model: str = "solaria-1",
203205
params: Optional[GladiaInputParams] = None,
204206
max_buffer_size: int = 1024 * 1024 * 20, # 20MB default buffer
207+
should_interrupt: bool = True,
205208
**kwargs,
206209
):
207210
"""Initialize the Gladia STT service.
@@ -220,6 +223,8 @@ def __init__(
220223
model: Model to use for transcription. Defaults to "solaria-1".
221224
params: Additional configuration parameters for Gladia service.
222225
max_buffer_size: Maximum size of audio buffer in bytes. Defaults to 20MB.
226+
should_interrupt: Whether the bot should be interrupted when Gladia VAD
227+
detects user speech. Only takes effect when ``params.enable_vad`` is True. Defaults to True.
223228
**kwargs: Additional arguments passed to the STTService parent class.
224229
"""
225230
super().__init__(sample_rate=sample_rate, **kwargs)
@@ -266,6 +271,10 @@ def __init__(
266271
self._max_buffer_size = max_buffer_size
267272
self._buffer_lock = asyncio.Lock()
268273

274+
# VAD state tracking
275+
self._is_speaking = False
276+
self._should_interrupt = should_interrupt
277+
269278
def __str__(self):
    """Return the service name followed by the current session id in brackets."""
    description = f"{self.name} [{self._session_id}]"
    return description
271280

@@ -507,6 +516,33 @@ async def _handle_transcription(
507516
await self.stop_ttfb_metrics()
508517
await self.stop_processing_metrics()
509518

519+
async def _on_speech_started(self):
    """React to a ``speech_start`` message received from Gladia.

    Nothing happens unless ``enable_vad`` is set in the service params and the
    user is not already marked as speaking. Otherwise the user is marked as
    speaking, a UserStartedSpeakingFrame is broadcast and, if
    ``should_interrupt`` was requested at construction time,
    ``push_interruption_task_frame_and_wait()`` is awaited.
    """
    vad_enabled = self._params.enable_vad
    # The speaking flag de-duplicates repeated start events for one utterance.
    if vad_enabled and not self._is_speaking:
        logger.debug(f"{self} User started speaking")
        self._is_speaking = True

        await self.broadcast_frame(UserStartedSpeakingFrame)
        if self._should_interrupt:
            await self.push_interruption_task_frame_and_wait()
534+
535+
async def _on_speech_ended(self):
    """React to a ``speech_end`` message received from Gladia.

    Nothing happens unless ``enable_vad`` is set in the service params and the
    user is currently marked as speaking. Otherwise the speaking flag is
    cleared and a UserStoppedSpeakingFrame is broadcast.
    """
    # Only a matching earlier start event (flag set) produces a stop frame.
    if self._params.enable_vad and self._is_speaking:
        self._is_speaking = False
        await self.broadcast_frame(UserStoppedSpeakingFrame)
        logger.debug(f"{self} User stopped speaking")
545+
510546
async def _send_audio(self, audio: bytes):
511547
"""Send audio chunk with proper message format."""
512548
if self._websocket and self._websocket.state is State.OPEN:
@@ -599,6 +635,10 @@ async def _receive_messages(self):
599635
translation, "", time_now_iso8601(), translated_language
600636
)
601637
)
638+
elif content["type"] == "speech_start":
639+
await self._on_speech_started()
640+
elif content["type"] == "speech_end":
641+
await self._on_speech_ended()
602642
except json.JSONDecodeError:
603643
logger.warning(f"{self} Received non-JSON message: {message}")
604644

0 commit comments

Comments
 (0)