Skip to content

Commit bf43032

Browse files
authored
Merge pull request pipecat-ai#3504 from pipecat-ai/aleix/nvidia-stt-tts-error-handling
NVIDIA STT/TTS error handling
2 parents 024809b + a010a02 commit bf43032

3 files changed

Lines changed: 135 additions & 109 deletions

File tree

changelog/3504.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Moved `NVIDIATTSService` and `NVIDIASTTService` client initialization from constructor to `start()` for better error handling.

src/pipecat/services/nvidia/stt.py

Lines changed: 84 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134

135135
params = params or NvidiaSTTService.InputParams()
136136

137+
self._server = server
137138
self._api_key = api_key
138139
self._use_ssl = use_ssl
139140
self._profanity_filter = False
@@ -162,18 +163,54 @@ def __init__(
162163

163164
self.set_model_name(model_function_map.get("model_name"))
164165

166+
self._asr_service = None
167+
self._queue = None
168+
self._config = None
169+
self._thread_task = None
170+
self._response_task = None
171+
172+
def _initialize_client(self):
165173
metadata = [
166174
["function-id", self._function_id],
167-
["authorization", f"Bearer {api_key}"],
175+
["authorization", f"Bearer {self._api_key}"],
168176
]
169-
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
177+
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
170178

171179
self._asr_service = riva.client.ASRService(auth)
172180

173-
self._queue = None
174-
self._config = None
175-
self._thread_task = None
176-
self._response_task = None
181+
def _create_recognition_config(self):
182+
"""Create the NVIDIA Riva ASR recognition configuration."""
183+
config = riva.client.StreamingRecognitionConfig(
184+
config=riva.client.RecognitionConfig(
185+
encoding=riva.client.AudioEncoding.LINEAR_PCM,
186+
language_code=self._language_code,
187+
model="",
188+
max_alternatives=1,
189+
profanity_filter=self._profanity_filter,
190+
enable_automatic_punctuation=self._automatic_punctuation,
191+
verbatim_transcripts=not self._no_verbatim_transcripts,
192+
sample_rate_hertz=self.sample_rate,
193+
audio_channel_count=1,
194+
),
195+
interim_results=True,
196+
)
197+
198+
riva.client.add_word_boosting_to_config(
199+
config, self._boosted_lm_words, self._boosted_lm_score
200+
)
201+
202+
riva.client.add_endpoint_parameters_to_config(
203+
config,
204+
self._start_history,
205+
self._start_threshold,
206+
self._stop_history,
207+
self._stop_history_eou,
208+
self._stop_threshold,
209+
self._stop_threshold_eou,
210+
)
211+
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
212+
213+
return config
177214

178215
def can_generate_metrics(self) -> bool:
179216
"""Check if this service can generate processing metrics.
@@ -206,41 +243,9 @@ async def start(self, frame: StartFrame):
206243
frame: StartFrame indicating pipeline start.
207244
"""
208245
await super().start(frame)
246+
self._initialize_client()
247+
self._config = self._create_recognition_config()
209248

210-
if self._config:
211-
return
212-
213-
config = riva.client.StreamingRecognitionConfig(
214-
config=riva.client.RecognitionConfig(
215-
encoding=riva.client.AudioEncoding.LINEAR_PCM,
216-
language_code=self._language_code,
217-
model="",
218-
max_alternatives=1,
219-
profanity_filter=self._profanity_filter,
220-
enable_automatic_punctuation=self._automatic_punctuation,
221-
verbatim_transcripts=not self._no_verbatim_transcripts,
222-
sample_rate_hertz=self.sample_rate,
223-
audio_channel_count=1,
224-
),
225-
interim_results=True,
226-
)
227-
228-
riva.client.add_word_boosting_to_config(
229-
config, self._boosted_lm_words, self._boosted_lm_score
230-
)
231-
232-
riva.client.add_endpoint_parameters_to_config(
233-
config,
234-
self._start_history,
235-
self._start_threshold,
236-
self._stop_history,
237-
self._stop_history_eou,
238-
self._stop_threshold,
239-
self._stop_threshold_eou,
240-
)
241-
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
242-
243-
self._config = config
244249
self._queue = asyncio.Queue()
245250

246251
if not self._thread_task:
@@ -250,6 +255,8 @@ async def start(self, frame: StartFrame):
250255
self._response_queue = asyncio.Queue()
251256
self._response_task = self.create_task(self._response_task_handler())
252257

258+
logger.debug(f"Initialized NvidiaSTTService with model: {self.model_name}")
259+
253260
async def stop(self, frame: EndFrame):
254261
"""Stop the NVIDIA Riva STT service and clean up resources.
255262
@@ -503,8 +510,6 @@ def _initialize_client(self):
503510
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
504511
self._asr_service = riva.client.ASRService(auth)
505512

506-
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
507-
508513
def _create_recognition_config(self):
509514
"""Create the NVIDIA Riva ASR recognition configuration."""
510515
# Create base configuration
@@ -572,6 +577,7 @@ async def start(self, frame: StartFrame):
572577
await super().start(frame)
573578
self._initialize_client()
574579
self._config = self._create_recognition_config()
580+
logger.debug(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
575581

576582
async def set_language(self, language: Language):
577583
"""Set the language for the STT service.
@@ -605,65 +611,53 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
605611
Frame: TranscriptionFrame containing the transcribed text.
606612
"""
607613
try:
608-
await self.start_processing_metrics()
609-
await self.start_ttfb_metrics()
610-
611-
# Make sure the client is initialized
612-
if self._asr_service is None:
613-
self._initialize_client()
614-
615-
# Make sure the config is created
616-
if self._config is None:
617-
self._config = self._create_recognition_config()
618-
619-
# Type assertion to satisfy the IDE
620614
assert self._asr_service is not None, "ASR service not initialized"
621615
assert self._config is not None, "Recognition config not created"
622616

617+
await self.start_processing_metrics()
618+
await self.start_ttfb_metrics()
619+
623620
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
624621
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
625622

626623
await self.stop_ttfb_metrics()
627624
await self.stop_processing_metrics()
628625

629626
# Process the response - handle different possible return types
630-
try:
631-
# If it's a future-like object, get the result
632-
if hasattr(raw_response, "result"):
633-
response = raw_response.result()
634-
else:
635-
response = raw_response
636-
637-
# Process transcription results
638-
transcription_found = False
639-
640-
# Now we can safely check results
641-
# Type hint for the IDE
642-
results = getattr(response, "results", [])
643-
644-
for result in results:
645-
alternatives = getattr(result, "alternatives", [])
646-
if alternatives:
647-
text = alternatives[0].transcript.strip()
648-
if text:
649-
logger.debug(f"Transcription: [{text}]")
650-
yield TranscriptionFrame(
651-
text,
652-
self._user_id,
653-
time_now_iso8601(),
654-
self._language_enum,
655-
)
656-
transcription_found = True
657-
658-
await self._handle_transcription(text, True, self._language_enum)
659-
660-
if not transcription_found:
661-
logger.debug("No transcription results found in NVIDIA Riva response")
662-
663-
except AttributeError as ae:
664-
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
665-
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
627+
# If it's a future-like object, get the result
628+
if hasattr(raw_response, "result"):
629+
response = raw_response.result()
630+
else:
631+
response = raw_response
632+
633+
# Process transcription results
634+
transcription_found = False
635+
636+
# Now we can safely check results
637+
# Type hint for the IDE
638+
results = getattr(response, "results", [])
639+
640+
for result in results:
641+
alternatives = getattr(result, "alternatives", [])
642+
if alternatives:
643+
text = alternatives[0].transcript.strip()
644+
if text:
645+
logger.debug(f"Transcription: [{text}]")
646+
yield TranscriptionFrame(
647+
text,
648+
self._user_id,
649+
time_now_iso8601(),
650+
self._language_enum,
651+
)
652+
transcription_found = True
653+
654+
await self._handle_transcription(text, True, self._language_enum)
666655

656+
if not transcription_found:
657+
logger.debug(f"{self}: No transcription results found in NVIDIA Riva response")
658+
except AttributeError as ae:
659+
logger.error(f"{self}: Unexpected response structure from NVIDIA Riva: {ae}")
660+
yield ErrorFrame(f"{self}: Unexpected NVIDIA Riva response format: {str(ae)}")
667661
except Exception as e:
668662
logger.error(f"{self} exception: {e}")
669663
yield ErrorFrame(error=f"{self} error: {e}")

src/pipecat/services/nvidia/tts.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pipecat.frames.frames import (
2626
ErrorFrame,
2727
Frame,
28+
StartFrame,
2829
TTSAudioRawFrame,
2930
TTSStartedFrame,
3031
TTSStoppedFrame,
@@ -93,6 +94,7 @@ def __init__(
9394

9495
params = params or NvidiaTTSService.InputParams()
9596

97+
self._server = server
9698
self._api_key = api_key
9799
self._voice_id = voice_id
98100
self._language_code = params.language
@@ -102,18 +104,8 @@ def __init__(
102104
self.set_model_name(model_function_map.get("model_name"))
103105
self.set_voice(voice_id)
104106

105-
metadata = [
106-
["function-id", self._function_id],
107-
["authorization", f"Bearer {api_key}"],
108-
]
109-
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
110-
111-
self._service = riva.client.SpeechSynthesisService(auth)
112-
113-
# warm up the service
114-
config_response = self._service.stub.GetRivaSynthesisConfig(
115-
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
116-
)
107+
self._service = None
108+
self._config = None
117109

118110
async def set_model(self, model: str):
119111
"""Attempt to set the TTS model.
@@ -129,6 +121,39 @@ async def set_model(self, model: str):
129121
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
130122
)
131123

124+
def _initialize_client(self):
125+
if self._service is not None:
126+
return
127+
128+
metadata = [
129+
["function-id", self._function_id],
130+
["authorization", f"Bearer {self._api_key}"],
131+
]
132+
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
133+
134+
self._service = riva.client.SpeechSynthesisService(auth)
135+
136+
def _create_synthesis_config(self):
137+
if not self._service:
138+
return
139+
140+
# warm up the service
141+
config = self._service.stub.GetRivaSynthesisConfig(
142+
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
143+
)
144+
return config
145+
146+
async def start(self, frame: StartFrame):
147+
"""Start the Cartesia TTS service.
148+
149+
Args:
150+
frame: The start frame containing initialization parameters.
151+
"""
152+
await super().start(frame)
153+
self._initialize_client()
154+
self._config = self._create_synthesis_config()
155+
logger.debug(f"Initialized NvidiaTTSService with model: {self.model_name}")
156+
132157
@traced_tts
133158
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
134159
"""Generate speech from text using NVIDIA Riva TTS.
@@ -161,12 +186,15 @@ def add_response(r):
161186
logger.error(f"{self} exception: {e}")
162187
add_response(None)
163188

164-
await self.start_ttfb_metrics()
165-
yield TTSStartedFrame()
189+
try:
190+
assert self._service is not None, "TTS service not initialized"
191+
assert self._config is not None, "Synthesis configuration not created"
192+
193+
await self.start_ttfb_metrics()
194+
yield TTSStartedFrame()
166195

167-
logger.debug(f"{self}: Generating TTS [{text}]")
196+
logger.debug(f"{self}: Generating TTS [{text}]")
168197

169-
try:
170198
queue = asyncio.Queue()
171199
await asyncio.to_thread(read_audio_responses, queue)
172200

@@ -181,9 +209,12 @@ def add_response(r):
181209
)
182210
yield frame
183211
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
212+
213+
await self.start_tts_usage_metrics(text)
214+
yield TTSStoppedFrame()
184215
except asyncio.TimeoutError:
185216
logger.error(f"{self} timeout waiting for audio response")
186217
yield ErrorFrame(error=f"{self} error: {e}")
187-
188-
await self.start_tts_usage_metrics(text)
189-
yield TTSStoppedFrame()
218+
except Exception as e:
219+
logger.error(f"{self} exception: {e}")
220+
yield ErrorFrame(error=f"{self} error: {e}")

0 commit comments

Comments
 (0)