1212
1313import asyncio
1414import os
15- from typing import AsyncGenerator , Mapping , Optional
15+ from typing import AsyncGenerator , AsyncIterable , Generator , Mapping , Optional
1616
1717from pipecat .utils .tracing .service_decorators import traced_tts
1818
3535
3636try :
3737 import riva .client
38-
38+ import riva . client . proto . riva_tts_pb2 as rtts
3939except ModuleNotFoundError as e :
4040 logger .error (f"Exception: { e } " )
4141 logger .error ("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`." )
4242 raise Exception (f"Missing module: { e } " )
4343
44- NVIDIA_TTS_TIMEOUT_SECS = 5
45-
4644
4745class NvidiaTTSService (TTSService ):
4846 """NVIDIA Riva text-to-speech service.
@@ -165,26 +163,30 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
165163 Frame: Audio frames containing the synthesized speech data.
166164 """
167165
168- def read_audio_responses (queue : asyncio .Queue ):
169- def add_response (r ):
170- asyncio .run_coroutine_threadsafe (queue .put (r ), self .get_event_loop ())
171-
166+ def read_audio_responses () -> Generator [rtts .SynthesizeSpeechResponse , None , None ]:
167+ responses = self ._service .synthesize_online (
168+ text ,
169+ self ._voice_id ,
170+ self ._language_code ,
171+ sample_rate_hz = self .sample_rate ,
172+ zero_shot_audio_prompt_file = None ,
173+ zero_shot_quality = self ._quality ,
174+ custom_dictionary = {},
175+ )
176+ return responses
177+
178+ def async_next (it ):
172179 try :
173- responses = self ._service .synthesize_online (
174- text ,
175- self ._voice_id ,
176- self ._language_code ,
177- sample_rate_hz = self .sample_rate ,
178- zero_shot_audio_prompt_file = None ,
179- zero_shot_quality = self ._quality ,
180- custom_dictionary = {},
181- )
182- for r in responses :
183- add_response (r )
184- add_response (None )
185- except Exception as e :
186- logger .error (f"{ self } exception: { e } " )
187- add_response (None )
180+ return next (it )
181+ except StopIteration :
182+ return None
183+
184+ async def async_iterator (iterator ) -> AsyncIterable [rtts .SynthesizeSpeechResponse ]:
185+ while True :
186+ item = await asyncio .to_thread (async_next , iterator )
187+ if item is None :
188+ return
189+ yield item
188190
189191 try :
190192 assert self ._service is not None , "TTS service not initialized"
@@ -195,20 +197,16 @@ def add_response(r):
195197
196198 logger .debug (f"{ self } : Generating TTS [{ text } ]" )
197199
198- queue = asyncio .Queue ()
199- await asyncio .to_thread (read_audio_responses , queue )
200+ responses = await asyncio .to_thread (read_audio_responses )
200201
201- # Wait for the thread to start.
202- resp = await asyncio .wait_for (queue .get (), timeout = NVIDIA_TTS_TIMEOUT_SECS )
203- while resp :
202+ async for resp in async_iterator (responses ):
204203 await self .stop_ttfb_metrics ()
205204 frame = TTSAudioRawFrame (
206205 audio = resp .audio ,
207206 sample_rate = self .sample_rate ,
208207 num_channels = 1 ,
209208 )
210209 yield frame
211- resp = await asyncio .wait_for (queue .get (), timeout = NVIDIA_TTS_TIMEOUT_SECS )
212210
213211 await self .start_tts_usage_metrics (text )
214212 yield TTSStoppedFrame ()
0 commit comments