From 8a9f31ab10c844700d8b97aed08828e64bf0e113 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 16 Mar 2026 21:43:26 -0400 Subject: [PATCH 01/10] Remove duplicate reconnection logic from Gradium STT The _receive_messages method had its own while-True reconnect loop, duplicating the reconnection handling already provided by WebsocketService._receive_task_handler (exponential backoff, max retries, error reporting). Flatten to just the inner message loop and let the base class handle reconnection. --- src/pipecat/services/gradium/stt.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index c328dceabe..6b428bd695 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -412,7 +412,7 @@ def _get_websocket(self): return self._websocket raise Exception("Websocket not connected") - async def _process_messages(self): + async def _receive_messages(self): async for message in self._get_websocket(): try: data = json.loads(message) @@ -420,12 +420,6 @@ async def _process_messages(self): except json.JSONDecodeError: logger.warning(f"Received non-JSON message: {message}") - async def _receive_messages(self): - while True: - await self._process_messages() - logger.debug(f"{self} Gradium connection was disconnected (timeout?), reconnecting") - await self._connect_websocket() - async def _process_response(self, msg): type_ = msg.get("type", "") if type_ == "text": From e0b09069a1c9bf08ab0777707e3140f46d026532 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 16 Mar 2026 21:49:54 -0400 Subject: [PATCH 02/10] Align Gradium STT VAD handling with base class patterns Replace the process_frame override with a _handle_vad_user_stopped_speaking override, which is the proper hook provided by STTService. Move start_processing_metrics() into run_stt (matching Gladia's pattern). Remove unused FrameDirection and VADUserStartedSpeakingFrame imports. 
--- src/pipecat/services/gradium/stt.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 6b428bd695..8287468652 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -24,10 +24,8 @@ Frame, StartFrame, TranscriptionFrame, - VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) -from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven from pipecat.services.stt_latency import GRADIUM_TTFS_P99 from pipecat.services.stt_service import WebsocketSTTService @@ -249,23 +247,17 @@ async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._disconnect() - async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames with VAD-specific handling. + async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame): + """Handle VAD user stopped speaking by flushing the transcription. - When VAD detects the user has stopped speaking, we flush the transcription - by sending silence frames. This makes the system more reactive by getting - the final transcription faster without closing the connection. + Calls the base class handler for TTFB tracking, then sends silence + frames to flush remaining audio from the model's buffer. Args: - frame: The frame to process. - direction: The direction of frame processing. + frame: The VAD user stopped speaking frame. """ - await super().process_frame(frame, direction) - - if isinstance(frame, VADUserStartedSpeakingFrame): - await self.start_processing_metrics() - elif isinstance(frame, VADUserStoppedSpeakingFrame): - await self._flush_transcription() + await super()._handle_vad_user_stopped_speaking(frame) + await self._flush_transcription() async def _flush_transcription(self): """Flush the transcription by sending silence frames. 
@@ -309,6 +301,7 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: Yields: None (processing handled via WebSocket messages). """ + await self.start_processing_metrics() self._audio_buffer.extend(audio) while len(self._audio_buffer) >= self._chunk_size_bytes: From d12f3bd78e7c2143f71017f528e898a86921b557 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 16 Mar 2026 21:51:22 -0400 Subject: [PATCH 03/10] Add keepalive support to Gradium STT service Enable the base class keepalive mechanism (10s timeout, 5s interval) and override _send_keepalive to wrap silence in Gradium's audio message format. Prevents idle connection timeouts, especially behind a ServiceSwitcher. --- src/pipecat/services/gradium/stt.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 8287468652..6a872c50e5 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -173,6 +173,8 @@ def __init__( super().__init__( sample_rate=SAMPLE_RATE, ttfs_p99_latency=ttfs_p99_latency, + keepalive_timeout=10, + keepalive_interval=5, settings=default_settings, **kwargs, ) @@ -400,6 +402,16 @@ async def _disconnect_websocket(self): self._websocket = None await self._call_event_handler("on_disconnected") + async def _send_keepalive(self, silence: bytes): + """Send silent audio to keep the Gradium connection alive. + + Args: + silence: Silent PCM audio bytes to send as a keepalive. 
+ """ + chunk = base64.b64encode(silence).decode("utf-8") + msg = json.dumps({"type": "audio", "audio": chunk}) + await self._websocket.send(msg) + def _get_websocket(self): if self._websocket: return self._websocket From c8c2ed4d8faf4e6e336c5407a75fb3d597930121 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 16 Mar 2026 22:24:33 -0400 Subject: [PATCH 04/10] Clean up Gradium STT message handling and add model_name to setup Inline _process_response into _receive_messages, add required model_name field to the setup message per Gradium docs, and improve _handle_text docstring. --- src/pipecat/services/gradium/stt.py | 39 +++++++++++++---------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 6a872c50e5..392ec4ec53 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -113,6 +113,7 @@ def __init__( *, api_key: str, api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr", + encoding: str = "pcm", params: Optional[InputParams] = None, json_config: Optional[str] = None, settings: Optional[Settings] = None, @@ -124,6 +125,7 @@ def __init__( Args: api_key: Gradium API key for authentication. api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint. + encoding: Audio input format. One of "pcm", "wav", or "opus". Defaults to "pcm". params: Configuration parameters for language and delay settings. .. deprecated:: 0.0.105 @@ -151,7 +153,7 @@ def __init__( # 1. 
Initialize default_settings with hardcoded defaults default_settings = self.Settings( - model=None, + model="default", language=None, delay_in_frames=None, ) @@ -181,6 +183,7 @@ def __init__( self._api_key = api_key self._api_endpoint_base_url = api_endpoint_base_url + self._encoding = encoding self._websocket = None self._json_config = json_config @@ -191,7 +194,7 @@ def __init__( self._chunk_size_bytes = 0 # Set from the ready message when connecting to the service. - # These values are used for flushing transcription. + # These values are used for flushing transcription via silence. self._delay_in_frames = 0 self._frame_size = 0 @@ -267,10 +270,6 @@ async def _flush_transcription(self): When VAD detects the user stopped speaking, we send delay_in_frames chunks of silence (zeros) to flush the remaining audio from the model's buffer. This allows for faster turn-around without closing the connection. - - From Gradium docs: "feed in delay_in_frames chunks of silence (vectors - of zeros). If those are fed in faster than realtime, the API also has - a possibility to process them faster." 
""" if not self._websocket or self._websocket.state is not State.OPEN: return @@ -348,7 +347,8 @@ async def _connect_websocket(self): await self._call_event_handler("on_connected") setup_msg = { "type": "setup", - "input_format": "pcm", + "model_name": self._settings.model, + "input_format": self._encoding, } # Build json_config: start with deprecated json_config, then override with params json_config = {} @@ -420,23 +420,18 @@ def _get_websocket(self): async def _receive_messages(self): async for message in self._get_websocket(): try: - data = json.loads(message) - await self._process_response(data) + msg = json.loads(message) except json.JSONDecodeError: logger.warning(f"Received non-JSON message: {message}") - - async def _process_response(self, msg): - type_ = msg.get("type", "") - if type_ == "text": - await self._handle_text(msg["text"]) - elif type_ == "end_of_stream": - await self._handle_end_of_stream() - elif type_ == "error": - await self.push_error(error_msg=f"Error: {msg}") - - async def _handle_end_of_stream(self): - """Handle termination message.""" - logger.debug("Received end_of_stream message from server") + continue + + type_ = msg.get("type", "") + if type_ == "text": + await self._handle_text(msg["text"]) + elif type_ == "end_of_stream": + logger.debug("Received end_of_stream message from server") + elif type_ == "error": + await self.push_error(error_msg=f"Error: {msg}") async def _handle_text(self, text: str): """Handle transcription results.""" From 5ce36db8a03d10b20c7c3f7b8bf84c0d734705da Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 17 Mar 2026 22:32:28 -0400 Subject: [PATCH 05/10] Rework Gradium STT to use flush API and text accumulation Replace silence-based flushing with Gradium flush/flushed protocol. Accumulate word-level text fragments as InterimTranscriptionFrames and emit a single TranscriptionFrame on flush completion. Align VAD handling with CartesiaSTTService pattern using process_frame override. 
Remove keepalive (not supported by Gradium) and pass language to transcription frames. --- src/pipecat/services/gradium/stt.py | 125 +++++++++++++++------------- 1 file changed, 68 insertions(+), 57 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 392ec4ec53..fdddf68f21 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -22,10 +22,13 @@ CancelFrame, EndFrame, Frame, + InterimTranscriptionFrame, StartFrame, TranscriptionFrame, + VADUserStartedSpeakingFrame, VADUserStoppedSpeakingFrame, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.settings import NOT_GIVEN, STTSettings, _NotGiven from pipecat.services.stt_latency import GRADIUM_TTFS_P99 from pipecat.services.stt_service import WebsocketSTTService @@ -175,8 +178,6 @@ def __init__( super().__init__( sample_rate=SAMPLE_RATE, ttfs_p99_latency=ttfs_p99_latency, - keepalive_timeout=10, - keepalive_interval=5, settings=default_settings, **kwargs, ) @@ -193,10 +194,12 @@ def __init__( self._chunk_size_ms = 80 self._chunk_size_bytes = 0 - # Set from the ready message when connecting to the service. - # These values are used for flushing transcription via silence. - self._delay_in_frames = 0 - self._frame_size = 0 + # Accumulates text fragments within a turn. Each "text" message + # appends to this list. On "flushed" the full text is joined and + # pushed as a TranscriptionFrame. Any trailing fragments are + # flushed when the user starts speaking again. + self._accumulated_text: list[str] = [] + self._flush_counter = 0 def can_generate_metrics(self) -> bool: """Check if the service can generate metrics. @@ -252,46 +255,41 @@ async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._disconnect() - async def _handle_vad_user_stopped_speaking(self, frame: VADUserStoppedSpeakingFrame): - """Handle VAD user stopped speaking by flushing the transcription. 
+ async def _start_metrics(self): + """Start performance metrics collection for transcription processing.""" + await self.start_processing_metrics() - Calls the base class handler for TTFB tracking, then sends silence - frames to flush remaining audio from the model's buffer. + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames and handle speech events. Args: - frame: The VAD user stopped speaking frame. + frame: The frame to process. + direction: Direction of frame flow in the pipeline. """ - await super()._handle_vad_user_stopped_speaking(frame) - await self._flush_transcription() + await super().process_frame(frame, direction) + + if isinstance(frame, VADUserStartedSpeakingFrame): + await self._start_metrics() + elif isinstance(frame, VADUserStoppedSpeakingFrame): + await self._send_flush() - async def _flush_transcription(self): - """Flush the transcription by sending silence frames. + async def _send_flush(self): + """Send a flush request to process any buffered audio immediately. - When VAD detects the user stopped speaking, we send delay_in_frames - chunks of silence (zeros) to flush the remaining audio from the model's - buffer. This allows for faster turn-around without closing the connection. + Sends a flush message to tell the server to process buffered audio. + The server responds with text fragments followed by a "flushed" + acknowledgment, which triggers finalization. 
""" if not self._websocket or self._websocket.state is not State.OPEN: return - if self._delay_in_frames <= 0: - logger.debug("No delay_in_frames set, skipping flush") - return - - # Create a silence chunk (zeros) of frame_size samples - # Each sample is 2 bytes (16-bit PCM) - silence_bytes = bytes(self._frame_size * 2) - silence_b64 = base64.b64encode(silence_bytes).decode("utf-8") - - logger.debug(f"Flushing Gradium STT with {self._delay_in_frames} silence frames") - - for _ in range(self._delay_in_frames): - msg = {"type": "audio", "audio": silence_b64} - try: - await self._websocket.send(json.dumps(msg)) - except Exception as e: - logger.warning(f"Failed to send silence frame: {e}") - break + self._flush_counter += 1 + flush_id = str(self._flush_counter) + msg = {"type": "flush", "flush_id": flush_id} + try: + await self._websocket.send(json.dumps(msg)) + except Exception as e: + logger.warning(f"Failed to send flush: {e}") async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: """Process audio data for speech-to-text conversion. @@ -302,7 +300,6 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: Yields: None (processing handled via WebSocket messages). 
""" - await self.start_processing_metrics() self._audio_buffer.extend(audio) while len(self._audio_buffer) >= self._chunk_size_bytes: @@ -370,13 +367,7 @@ async def _connect_websocket(self): if ready_msg["type"] != "ready": raise Exception(f"unexpected first message type {ready_msg['type']}") - # Store delay_in_frames and frame_size for silence flushing - self._delay_in_frames = ready_msg.get("delay_in_frames", 0) - self._frame_size = ready_msg.get("frame_size", 1920) - logger.debug( - f"Connected to Gradium STT (delay_in_frames={self._delay_in_frames}, " - f"frame_size={self._frame_size})" - ) + logger.debug("Connected to Gradium STT") except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) @@ -402,16 +393,6 @@ async def _disconnect_websocket(self): self._websocket = None await self._call_event_handler("on_disconnected") - async def _send_keepalive(self, silence: bytes): - """Send silent audio to keep the Gradium connection alive. - - Args: - silence: Silent PCM audio bytes to send as a keepalive. - """ - chunk = base64.b64encode(silence).decode("utf-8") - msg = json.dumps({"type": "audio", "audio": chunk}) - await self._websocket.send(msg) - def _get_websocket(self): if self._websocket: return self._websocket @@ -428,19 +409,49 @@ async def _receive_messages(self): type_ = msg.get("type", "") if type_ == "text": await self._handle_text(msg["text"]) + elif type_ == "flushed": + await self._handle_flushed() elif type_ == "end_of_stream": logger.debug("Received end_of_stream message from server") elif type_ == "error": await self.push_error(error_msg=f"Error: {msg}") async def _handle_text(self, text: str): - """Handle transcription results.""" + """Handle streaming transcription fragment. + + Accumulates text and pushes an InterimTranscriptionFrame with the + full accumulated text so far. 
+ """ + self._accumulated_text.append(text) + accumulated = " ".join(self._accumulated_text) + await self.push_frame( + InterimTranscriptionFrame( + text=accumulated, + user_id=self._user_id, + timestamp=time_now_iso8601(), + language=self._settings.language, + ) + ) + await self.stop_processing_metrics() + + async def _handle_flushed(self): + """Handle flush completion by pushing the finalized transcription. + + The "flushed" message confirms that buffered audio has been processed. + Any trailing text fragments that arrive after this will be caught by + the TTFB timeout handler. + """ + if not self._accumulated_text: + return + text = " ".join(self._accumulated_text) + self._accumulated_text.clear() + logger.debug(f"Final transcription: [{text}]") await self.push_frame( TranscriptionFrame( text, self._user_id, time_now_iso8601(), + self._settings.language, ) ) - await self._trace_transcription(text, is_final=True, language=None) - await self.stop_processing_metrics() + await self._trace_transcription(text, is_final=True, language=self._settings.language) From d58a47d12a09f0a69f432403d70cb785f3751e24 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 17 Mar 2026 23:07:28 -0400 Subject: [PATCH 06/10] Add transcript aggregation delay after flushed to capture trailing tokens Gradium flushed response can arrive before all text tokens have been delivered. Instead of finalizing immediately on flushed, start a short timer (100ms) that allows trailing tokens to accumulate before pushing the final TranscriptionFrame. --- src/pipecat/services/gradium/stt.py | 36 +++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index fdddf68f21..297a226fea 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -10,6 +10,7 @@ WebSocket API for streaming audio transcription. 
""" +import asyncio import base64 import json from dataclasses import dataclass, field @@ -195,11 +196,13 @@ def __init__( self._chunk_size_bytes = 0 # Accumulates text fragments within a turn. Each "text" message - # appends to this list. On "flushed" the full text is joined and - # pushed as a TranscriptionFrame. Any trailing fragments are - # flushed when the user starts speaking again. + # appends to this list. On "flushed" a short aggregation delay + # allows trailing tokens to arrive before the full text is joined + # and pushed as a TranscriptionFrame. self._accumulated_text: list[str] = [] self._flush_counter = 0 + self._transcript_aggregation_delay = 0.1 # seconds to wait after flushed + self._transcript_aggregation_task: Optional[asyncio.Task] = None def can_generate_metrics(self) -> bool: """Check if the service can generate metrics. @@ -376,6 +379,10 @@ async def _connect_websocket(self): async def _disconnect(self): await super()._disconnect() + if self._transcript_aggregation_task: + await self.cancel_task(self._transcript_aggregation_task) + self._transcript_aggregation_task = None + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None @@ -435,14 +442,29 @@ async def _handle_text(self, text: str): await self.stop_processing_metrics() async def _handle_flushed(self): - """Handle flush completion by pushing the finalized transcription. + """Handle flush completion by starting a transcript aggregation timer. - The "flushed" message confirms that buffered audio has been processed. - Any trailing text fragments that arrive after this will be caught by - the TTFB timeout handler. + The "flushed" message confirms that buffered audio has been processed, + but text tokens may still arrive after this point. A short timer allows + trailing tokens to accumulate before finalizing the transcription. 
""" + if self._transcript_aggregation_task: + await self.cancel_task(self._transcript_aggregation_task) + self._transcript_aggregation_task = self.create_task( + self._transcript_aggregation_handler(), "transcript_aggregation" + ) + + async def _transcript_aggregation_handler(self): + """Wait for trailing tokens then finalize the accumulated transcription.""" + await asyncio.sleep(self._transcript_aggregation_delay) + await self._finalize_accumulated_text() + + async def _finalize_accumulated_text(self): + """Join accumulated text, push TranscriptionFrame, and clear state.""" if not self._accumulated_text: return + self._transcript_aggregation_task = None + text = " ".join(self._accumulated_text) self._accumulated_text.clear() logger.debug(f"Final transcription: [{text}]") From c6945c55f319fd9afbedab1b5b75350be1017eaa Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Tue, 17 Mar 2026 23:10:29 -0400 Subject: [PATCH 07/10] Add changelog for PR #4066 --- changelog/4066.changed.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/4066.changed.md diff --git a/changelog/4066.changed.md b/changelog/4066.changed.md new file mode 100644 index 0000000000..65e95ff2c0 --- /dev/null +++ b/changelog/4066.changed.md @@ -0,0 +1 @@ +- Improved `GradiumSTTService` transcription accuracy by reworking how text fragments are accumulated and finalized. Previously, trailing words could be dropped when the server's `flushed` response arrived before all text tokens were delivered. The service now uses a short aggregation delay after flush to capture trailing tokens, producing complete utterances. 
From b0f77bc7c4266e7a78a3ecbbc9ccd03c4f73fc2e Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 18 Mar 2026 09:00:32 -0400 Subject: [PATCH 08/10] Change default encoding to pcm_16000 --- src/pipecat/services/gradium/stt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 297a226fea..2572028c38 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -117,7 +117,7 @@ def __init__( *, api_key: str, api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr", - encoding: str = "pcm", + encoding: str = "pcm_16000", params: Optional[InputParams] = None, json_config: Optional[str] = None, settings: Optional[Settings] = None, @@ -129,7 +129,8 @@ def __init__( Args: api_key: Gradium API key for authentication. api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint. - encoding: Audio input format. One of "pcm", "wav", or "opus". Defaults to "pcm". + encoding: Audio input format. One of "pcm", "pcm_16000", "wav", or "opus". Defaults to + "pcm_16000". params: Configuration parameters for language and delay settings. .. deprecated:: 0.0.105 From 4d55a8e8f4dc0219367bdc2f2b831ebd86c69d01 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 18 Mar 2026 15:22:05 -0400 Subject: [PATCH 09/10] Code review feedback --- src/pipecat/services/gradium/stt.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index 2572028c38..fe07786dc8 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -46,6 +46,9 @@ raise Exception(f"Missing module: {e}") SAMPLE_RATE = 24000 +# Seconds to wait after a "flushed" message for trailing text tokens to arrive +# before finalizing the transcription. 
+TRANSCRIPT_AGGREGATION_DELAY = 0.1 def language_to_gradium_language(language: Language) -> Optional[str]: @@ -202,7 +205,6 @@ def __init__( # and pushed as a TranscriptionFrame. self._accumulated_text: list[str] = [] self._flush_counter = 0 - self._transcript_aggregation_delay = 0.1 # seconds to wait after flushed self._transcript_aggregation_task: Optional[asyncio.Task] = None def can_generate_metrics(self) -> bool: @@ -384,6 +386,9 @@ async def _disconnect(self): await self.cancel_task(self._transcript_aggregation_task) self._transcript_aggregation_task = None + self._accumulated_text.clear() + self._flush_counter = 0 + if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None @@ -457,7 +462,7 @@ async def _handle_flushed(self): async def _transcript_aggregation_handler(self): """Wait for trailing tokens then finalize the accumulated transcription.""" - await asyncio.sleep(self._transcript_aggregation_delay) + await asyncio.sleep(TRANSCRIPT_AGGREGATION_DELAY) await self._finalize_accumulated_text() async def _finalize_accumulated_text(self): From 4d9d8afb3dbf262fa338dcab3a9a3ee8ffd81f2c Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Wed, 18 Mar 2026 15:47:49 -0400 Subject: [PATCH 10/10] Decouple encoding from sample_rate in Gradium STT The encoding parameter now takes just the base type (pcm, wav, opus) and the sample rate is derived from the pipeline audio_in_sample_rate, assembled dynamically via input_format_from_encoding(). This fixes the mismatch where SAMPLE_RATE=24000 was passed to the base class while encoding defaulted to pcm_16000. 
--- changelog/4066.changed.2.md | 1 + src/pipecat/services/gradium/stt.py | 47 +++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 6 deletions(-) create mode 100644 changelog/4066.changed.2.md diff --git a/changelog/4066.changed.2.md b/changelog/4066.changed.2.md new file mode 100644 index 0000000000..751961f106 --- /dev/null +++ b/changelog/4066.changed.2.md @@ -0,0 +1 @@ +- `GradiumSTTService` now takes both an `encoding` and a `sample_rate` constructor argument, which are assembled in the class to form the `input_format`. PCM accepts `8000`, `16000`, and `24000` Hz sample rates. diff --git a/src/pipecat/services/gradium/stt.py b/src/pipecat/services/gradium/stt.py index fe07786dc8..5dea2c8241 100644 --- a/src/pipecat/services/gradium/stt.py +++ b/src/pipecat/services/gradium/stt.py @@ -45,12 +45,39 @@ logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.') raise Exception(f"Missing module: {e}") -SAMPLE_RATE = 24000 # Seconds to wait after a "flushed" message for trailing text tokens to arrive # before finalizing the transcription. TRANSCRIPT_AGGREGATION_DELAY = 0.1 +def _input_format_from_encoding(encoding: str, sample_rate: int) -> str: + """Build Gradium input_format from encoding type and sample rate. + + For PCM encoding, appends the sample rate (e.g., "pcm_16000"). + For other encodings (wav, opus), returns the encoding as-is. + + Args: + encoding: Base encoding type ("pcm", "wav", or "opus"). + sample_rate: Audio sample rate in Hz. + + Returns: + The full input_format string for the Gradium API. 
+ """ + if encoding == "pcm": + match sample_rate: + case 8000: + return "pcm_8000" + case 16000: + return "pcm_16000" + case 24000: + return "pcm_24000" + logger.warning( + f"GradiumSTTService: unsupported sample rate {sample_rate} for PCM encoding, using pcm_16000" + ) + return "pcm_16000" + return encoding + + def language_to_gradium_language(language: Language) -> Optional[str]: """Convert a Language enum to Gradium's language code format. @@ -120,7 +147,8 @@ def __init__( *, api_key: str, api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr", - encoding: str = "pcm_16000", + encoding: str = "pcm", + sample_rate: Optional[int] = None, params: Optional[InputParams] = None, json_config: Optional[str] = None, settings: Optional[Settings] = None, @@ -132,8 +160,12 @@ def __init__( Args: api_key: Gradium API key for authentication. api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint. - encoding: Audio input format. One of "pcm", "pcm_16000", "wav", or "opus". Defaults to - "pcm_16000". + encoding: Base audio encoding type. One of "pcm", "wav", or "opus". + For PCM, the sample rate is appended automatically from the + pipeline's audio_in_sample_rate (e.g., "pcm" becomes "pcm_16000"). + Defaults to "pcm". + sample_rate: Audio sample rate in Hz. If None, uses the pipeline + sample rate. params: Configuration parameters for language and delay settings. .. deprecated:: 0.0.105 @@ -181,7 +213,7 @@ def __init__( default_settings.apply_update(settings) super().__init__( - sample_rate=SAMPLE_RATE, + sample_rate=sample_rate, ttfs_p99_latency=ttfs_p99_latency, settings=default_settings, **kwargs, @@ -195,6 +227,8 @@ def __init__( self._receive_task = None + self._input_format = "" + self._audio_buffer = bytearray() self._chunk_size_ms = 80 self._chunk_size_bytes = 0 @@ -240,6 +274,7 @@ async def start(self, frame: StartFrame): frame: Start frame to begin processing. 
""" await super().start(frame) + self._input_format = _input_format_from_encoding(self._encoding, self.sample_rate) self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000) await self._connect() @@ -351,7 +386,7 @@ async def _connect_websocket(self): setup_msg = { "type": "setup", "model_name": self._settings.model, - "input_format": self._encoding, + "input_format": self._input_format, } # Build json_config: start with deprecated json_config, then override with params json_config = {}