Skip to content

Commit ef794ff

Browse files
committed
Decouple encoding from sample_rate in Gradium STT
The encoding parameter now takes just the base type (pcm, wav, opus) and the sample rate is derived from the pipeline audio_in_sample_rate, assembled dynamically via input_format_from_encoding(). This fixes the mismatch where SAMPLE_RATE=24000 was passed to the base class while encoding defaulted to pcm_16000.
1 parent 4d55a8e commit ef794ff

2 files changed

Lines changed: 42 additions & 6 deletions

File tree

changelog/4066.changed.2.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- `GradiumSTTService` now takes both an `encoding` and a `sample_rate` constructor argument, which are assembled in the class to form the `input_format`. PCM accepts `8000`, `16000`, and `24000` Hz sample rates.

src/pipecat/services/gradium/stt.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,39 @@
4545
logger.error('In order to use Gradium, you need to `pip install "pipecat-ai[gradium]"`.')
4646
raise Exception(f"Missing module: {e}")
4747

48-
SAMPLE_RATE = 24000
4948
# Seconds to wait after a "flushed" message for trailing text tokens to arrive
5049
# before finalizing the transcription.
5150
TRANSCRIPT_AGGREGATION_DELAY = 0.1
5251

5352

53+
def input_format_from_encoding(encoding: str, sample_rate: int) -> str:
    """Build the Gradium input_format string from an encoding and sample rate.

    For PCM encoding, appends the sample rate (e.g., "pcm_16000").
    Any other encoding value (e.g., "wav", "opus") is returned as-is and the
    sample rate is not consulted.

    Args:
        encoding: Base encoding type ("pcm", "wav", or "opus").
        sample_rate: Audio sample rate in Hz. Only used when ``encoding`` is
            "pcm".

    Returns:
        The full input_format string for the Gradium API. For PCM with a
        sample rate other than 8000, 16000, or 24000, a warning is logged and
        "pcm_16000" is returned.
    """
    if encoding != "pcm":
        return encoding

    # A lookup table instead of a constant-only ``match``: the supported rates
    # live in one place, and dict lookup keeps the same equality semantics as
    # the literal ``case`` patterns it replaces.
    pcm_formats = {
        8000: "pcm_8000",
        16000: "pcm_16000",
        24000: "pcm_24000",
    }
    input_format = pcm_formats.get(sample_rate)
    if input_format is not None:
        return input_format

    logger.warning(
        f"GradiumSTTService: unsupported sample rate {sample_rate} for PCM encoding, using pcm_16000"
    )
    return "pcm_16000"
79+
80+
5481
def language_to_gradium_language(language: Language) -> Optional[str]:
5582
"""Convert a Language enum to Gradium's language code format.
5683
@@ -120,7 +147,8 @@ def __init__(
120147
*,
121148
api_key: str,
122149
api_endpoint_base_url: str = "wss://eu.api.gradium.ai/api/speech/asr",
123-
encoding: str = "pcm_16000",
150+
encoding: str = "pcm",
151+
sample_rate: Optional[int] = None,
124152
params: Optional[InputParams] = None,
125153
json_config: Optional[str] = None,
126154
settings: Optional[Settings] = None,
@@ -132,8 +160,12 @@ def __init__(
132160
Args:
133161
api_key: Gradium API key for authentication.
134162
api_endpoint_base_url: WebSocket endpoint URL. Defaults to Gradium's streaming endpoint.
135-
encoding: Audio input format. One of "pcm", "pcm_16000", "wav", or "opus". Defaults to
136-
"pcm_16000".
163+
encoding: Base audio encoding type. One of "pcm", "wav", or "opus".
164+
For PCM, the sample rate is appended automatically from the
165+
pipeline's audio_in_sample_rate (e.g., "pcm" becomes "pcm_16000").
166+
Defaults to "pcm".
167+
sample_rate: Audio sample rate in Hz. If None, uses the pipeline
168+
sample rate.
137169
params: Configuration parameters for language and delay settings.
138170
139171
.. deprecated:: 0.0.105
@@ -181,7 +213,7 @@ def __init__(
181213
default_settings.apply_update(settings)
182214

183215
super().__init__(
184-
sample_rate=SAMPLE_RATE,
216+
sample_rate=sample_rate,
185217
ttfs_p99_latency=ttfs_p99_latency,
186218
settings=default_settings,
187219
**kwargs,
@@ -195,6 +227,8 @@ def __init__(
195227

196228
self._receive_task = None
197229

230+
self._input_format = ""
231+
198232
self._audio_buffer = bytearray()
199233
self._chunk_size_ms = 80
200234
self._chunk_size_bytes = 0
@@ -240,6 +274,7 @@ async def start(self, frame: StartFrame):
240274
frame: Start frame to begin processing.
241275
"""
242276
await super().start(frame)
277+
self._input_format = input_format_from_encoding(self._encoding, self.sample_rate)
243278
self._chunk_size_bytes = int(self._chunk_size_ms * self.sample_rate * 2 / 1000)
244279
await self._connect()
245280

@@ -351,7 +386,7 @@ async def _connect_websocket(self):
351386
setup_msg = {
352387
"type": "setup",
353388
"model_name": self._settings.model,
354-
"input_format": self._encoding,
389+
"input_format": self._input_format,
355390
}
356391
# Build json_config: start with deprecated json_config, then override with params
357392
json_config = {}

0 commit comments

Comments
 (0)