@@ -134,6 +134,7 @@ def __init__(
134134
135135 params = params or NvidiaSTTService .InputParams ()
136136
137+ self ._server = server
137138 self ._api_key = api_key
138139 self ._use_ssl = use_ssl
139140 self ._profanity_filter = False
@@ -162,18 +163,54 @@ def __init__(
162163
163164 self .set_model_name (model_function_map .get ("model_name" ))
164165
166+ self ._asr_service = None
167+ self ._queue = None
168+ self ._config = None
169+ self ._thread_task = None
170+ self ._response_task = None
171+
172+ def _initialize_client (self ):
165173 metadata = [
166174 ["function-id" , self ._function_id ],
167- ["authorization" , f"Bearer { api_key } " ],
175+ ["authorization" , f"Bearer { self . _api_key } " ],
168176 ]
169- auth = riva .client .Auth (None , self ._use_ssl , server , metadata )
177+ auth = riva .client .Auth (None , self ._use_ssl , self . _server , metadata )
170178
171179 self ._asr_service = riva .client .ASRService (auth )
172180
173- self ._queue = None
174- self ._config = None
175- self ._thread_task = None
176- self ._response_task = None
181+ def _create_recognition_config (self ):
182+ """Create the NVIDIA Riva ASR recognition configuration."""
183+ config = riva .client .StreamingRecognitionConfig (
184+ config = riva .client .RecognitionConfig (
185+ encoding = riva .client .AudioEncoding .LINEAR_PCM ,
186+ language_code = self ._language_code ,
187+ model = "" ,
188+ max_alternatives = 1 ,
189+ profanity_filter = self ._profanity_filter ,
190+ enable_automatic_punctuation = self ._automatic_punctuation ,
191+ verbatim_transcripts = not self ._no_verbatim_transcripts ,
192+ sample_rate_hertz = self .sample_rate ,
193+ audio_channel_count = 1 ,
194+ ),
195+ interim_results = True ,
196+ )
197+
198+ riva .client .add_word_boosting_to_config (
199+ config , self ._boosted_lm_words , self ._boosted_lm_score
200+ )
201+
202+ riva .client .add_endpoint_parameters_to_config (
203+ config ,
204+ self ._start_history ,
205+ self ._start_threshold ,
206+ self ._stop_history ,
207+ self ._stop_history_eou ,
208+ self ._stop_threshold ,
209+ self ._stop_threshold_eou ,
210+ )
211+ riva .client .add_custom_configuration_to_config (config , self ._custom_configuration )
212+
213+ return config
177214
178215 def can_generate_metrics (self ) -> bool :
179216 """Check if this service can generate processing metrics.
@@ -206,41 +243,9 @@ async def start(self, frame: StartFrame):
206243 frame: StartFrame indicating pipeline start.
207244 """
208245 await super ().start (frame )
246+ self ._initialize_client ()
247+ self ._config = self ._create_recognition_config ()
209248
210- if self ._config :
211- return
212-
213- config = riva .client .StreamingRecognitionConfig (
214- config = riva .client .RecognitionConfig (
215- encoding = riva .client .AudioEncoding .LINEAR_PCM ,
216- language_code = self ._language_code ,
217- model = "" ,
218- max_alternatives = 1 ,
219- profanity_filter = self ._profanity_filter ,
220- enable_automatic_punctuation = self ._automatic_punctuation ,
221- verbatim_transcripts = not self ._no_verbatim_transcripts ,
222- sample_rate_hertz = self .sample_rate ,
223- audio_channel_count = 1 ,
224- ),
225- interim_results = True ,
226- )
227-
228- riva .client .add_word_boosting_to_config (
229- config , self ._boosted_lm_words , self ._boosted_lm_score
230- )
231-
232- riva .client .add_endpoint_parameters_to_config (
233- config ,
234- self ._start_history ,
235- self ._start_threshold ,
236- self ._stop_history ,
237- self ._stop_history_eou ,
238- self ._stop_threshold ,
239- self ._stop_threshold_eou ,
240- )
241- riva .client .add_custom_configuration_to_config (config , self ._custom_configuration )
242-
243- self ._config = config
244249 self ._queue = asyncio .Queue ()
245250
246251 if not self ._thread_task :
@@ -250,6 +255,8 @@ async def start(self, frame: StartFrame):
250255 self ._response_queue = asyncio .Queue ()
251256 self ._response_task = self .create_task (self ._response_task_handler ())
252257
258+ logger .debug (f"Initialized NvidiaSTTService with model: { self .model_name } " )
259+
253260 async def stop (self , frame : EndFrame ):
254261 """Stop the NVIDIA Riva STT service and clean up resources.
255262
@@ -503,8 +510,6 @@ def _initialize_client(self):
503510 auth = riva .client .Auth (None , self ._use_ssl , self ._server , metadata )
504511 self ._asr_service = riva .client .ASRService (auth )
505512
506- logger .info (f"Initialized NvidiaSegmentedSTTService with model: { self .model_name } " )
507-
508513 def _create_recognition_config (self ):
509514 """Create the NVIDIA Riva ASR recognition configuration."""
510515 # Create base configuration
@@ -572,6 +577,7 @@ async def start(self, frame: StartFrame):
572577 await super ().start (frame )
573578 self ._initialize_client ()
574579 self ._config = self ._create_recognition_config ()
580+ logger .debug (f"Initialized NvidiaSegmentedSTTService with model: { self .model_name } " )
575581
576582 async def set_language (self , language : Language ):
577583 """Set the language for the STT service.
@@ -605,65 +611,53 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
605611 Frame: TranscriptionFrame containing the transcribed text.
606612 """
607613 try :
608- await self .start_processing_metrics ()
609- await self .start_ttfb_metrics ()
610-
611- # Make sure the client is initialized
612- if self ._asr_service is None :
613- self ._initialize_client ()
614-
615- # Make sure the config is created
616- if self ._config is None :
617- self ._config = self ._create_recognition_config ()
618-
619- # Type assertion to satisfy the IDE
620614 assert self ._asr_service is not None , "ASR service not initialized"
621615 assert self ._config is not None , "Recognition config not created"
622616
617+ await self .start_processing_metrics ()
618+ await self .start_ttfb_metrics ()
619+
623620 # Process audio with NVIDIA Riva ASR - explicitly request non-future response
624621 raw_response = self ._asr_service .offline_recognize (audio , self ._config , future = False )
625622
626623 await self .stop_ttfb_metrics ()
627624 await self .stop_processing_metrics ()
628625
629626 # Process the response - handle different possible return types
630- try :
631- # If it's a future-like object, get the result
632- if hasattr (raw_response , "result" ):
633- response = raw_response .result ()
634- else :
635- response = raw_response
636-
637- # Process transcription results
638- transcription_found = False
639-
640- # Now we can safely check results
641- # Type hint for the IDE
642- results = getattr (response , "results" , [])
643-
644- for result in results :
645- alternatives = getattr (result , "alternatives" , [])
646- if alternatives :
647- text = alternatives [0 ].transcript .strip ()
648- if text :
649- logger .debug (f"Transcription: [{ text } ]" )
650- yield TranscriptionFrame (
651- text ,
652- self ._user_id ,
653- time_now_iso8601 (),
654- self ._language_enum ,
655- )
656- transcription_found = True
657-
658- await self ._handle_transcription (text , True , self ._language_enum )
659-
660- if not transcription_found :
661- logger .debug ("No transcription results found in NVIDIA Riva response" )
662-
663- except AttributeError as ae :
664- logger .error (f"Unexpected response structure from NVIDIA Riva: { ae } " )
665- yield ErrorFrame (f"Unexpected NVIDIA Riva response format: { str (ae )} " )
627+ # If it's a future-like object, get the result
628+ if hasattr (raw_response , "result" ):
629+ response = raw_response .result ()
630+ else :
631+ response = raw_response
632+
633+ # Process transcription results
634+ transcription_found = False
635+
636+ # Now we can safely check results
637+ # Type hint for the IDE
638+ results = getattr (response , "results" , [])
639+
640+ for result in results :
641+ alternatives = getattr (result , "alternatives" , [])
642+ if alternatives :
643+ text = alternatives [0 ].transcript .strip ()
644+ if text :
645+ logger .debug (f"Transcription: [{ text } ]" )
646+ yield TranscriptionFrame (
647+ text ,
648+ self ._user_id ,
649+ time_now_iso8601 (),
650+ self ._language_enum ,
651+ )
652+ transcription_found = True
653+
654+ await self ._handle_transcription (text , True , self ._language_enum )
666655
656+ if not transcription_found :
657+ logger .debug (f"{ self } : No transcription results found in NVIDIA Riva response" )
658+ except AttributeError as ae :
659+ logger .error (f"{ self } : Unexpected response structure from NVIDIA Riva: { ae } " )
660+ yield ErrorFrame (f"{ self } : Unexpected NVIDIA Riva response format: { str (ae )} " )
667661 except Exception as e :
668662 logger .error (f"{ self } exception: { e } " )
669663 yield ErrorFrame (error = f"{ self } error: { e } " )
0 commit comments