Skip to content

Commit 671dc8c

Browse files
committed
NvidiaSTTService: initialize client on StartFrame
Initialize client on StartFrame so errrors are reported within the pipeline.
1 parent 9a718de commit 671dc8c

1 file changed

Lines changed: 47 additions & 40 deletions

File tree

  • src/pipecat/services/nvidia

src/pipecat/services/nvidia/stt.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134

135135
params = params or NvidiaSTTService.InputParams()
136136

137+
self._server = server
137138
self._api_key = api_key
138139
self._use_ssl = use_ssl
139140
self._profanity_filter = False
@@ -162,18 +163,54 @@ def __init__(
162163

163164
self.set_model_name(model_function_map.get("model_name"))
164165

166+
self._asr_service = None
167+
self._queue = None
168+
self._config = None
169+
self._thread_task = None
170+
self._response_task = None
171+
172+
def _initialize_client(self):
165173
metadata = [
166174
["function-id", self._function_id],
167-
["authorization", f"Bearer {api_key}"],
175+
["authorization", f"Bearer {self._api_key}"],
168176
]
169-
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
177+
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
170178

171179
self._asr_service = riva.client.ASRService(auth)
172180

173-
self._queue = None
174-
self._config = None
175-
self._thread_task = None
176-
self._response_task = None
181+
def _create_recognition_config(self):
182+
"""Create the NVIDIA Riva ASR recognition configuration."""
183+
config = riva.client.StreamingRecognitionConfig(
184+
config=riva.client.RecognitionConfig(
185+
encoding=riva.client.AudioEncoding.LINEAR_PCM,
186+
language_code=self._language_code,
187+
model="",
188+
max_alternatives=1,
189+
profanity_filter=self._profanity_filter,
190+
enable_automatic_punctuation=self._automatic_punctuation,
191+
verbatim_transcripts=not self._no_verbatim_transcripts,
192+
sample_rate_hertz=self.sample_rate,
193+
audio_channel_count=1,
194+
),
195+
interim_results=True,
196+
)
197+
198+
riva.client.add_word_boosting_to_config(
199+
config, self._boosted_lm_words, self._boosted_lm_score
200+
)
201+
202+
riva.client.add_endpoint_parameters_to_config(
203+
config,
204+
self._start_history,
205+
self._start_threshold,
206+
self._stop_history,
207+
self._stop_history_eou,
208+
self._stop_threshold,
209+
self._stop_threshold_eou,
210+
)
211+
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
212+
213+
return config
177214

178215
def can_generate_metrics(self) -> bool:
179216
"""Check if this service can generate processing metrics.
@@ -206,41 +243,9 @@ async def start(self, frame: StartFrame):
206243
frame: StartFrame indicating pipeline start.
207244
"""
208245
await super().start(frame)
246+
self._initialize_client()
247+
self._config = self._create_recognition_config()
209248

210-
if self._config:
211-
return
212-
213-
config = riva.client.StreamingRecognitionConfig(
214-
config=riva.client.RecognitionConfig(
215-
encoding=riva.client.AudioEncoding.LINEAR_PCM,
216-
language_code=self._language_code,
217-
model="",
218-
max_alternatives=1,
219-
profanity_filter=self._profanity_filter,
220-
enable_automatic_punctuation=self._automatic_punctuation,
221-
verbatim_transcripts=not self._no_verbatim_transcripts,
222-
sample_rate_hertz=self.sample_rate,
223-
audio_channel_count=1,
224-
),
225-
interim_results=True,
226-
)
227-
228-
riva.client.add_word_boosting_to_config(
229-
config, self._boosted_lm_words, self._boosted_lm_score
230-
)
231-
232-
riva.client.add_endpoint_parameters_to_config(
233-
config,
234-
self._start_history,
235-
self._start_threshold,
236-
self._stop_history,
237-
self._stop_history_eou,
238-
self._stop_threshold,
239-
self._stop_threshold_eou,
240-
)
241-
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
242-
243-
self._config = config
244249
self._queue = asyncio.Queue()
245250

246251
if not self._thread_task:
@@ -250,6 +255,8 @@ async def start(self, frame: StartFrame):
250255
self._response_queue = asyncio.Queue()
251256
self._response_task = self.create_task(self._response_task_handler())
252257

258+
logger.debug(f"Initialized NvidiaSTTService with model: {self.model_name}")
259+
253260
async def stop(self, frame: EndFrame):
254261
"""Stop the NVIDIA Riva STT service and clean up resources.
255262

0 commit comments

Comments
 (0)