Skip to content

Commit a787fd9

Browse files
committed
NVIDIATTSService: process incoming audio frame right away
Process audio as soon as we receive it from the generator. Previously, we were reading from the generator and adding elements into a queue until there was no more data, then we would process the queue.
1 parent 14495c4 commit a787fd9

2 files changed

Lines changed: 28 additions & 29 deletions

File tree

changelog/3509.fixed.2.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Optimized `NVIDIATTSService` to process incoming audio frames immediately.

src/pipecat/services/nvidia/tts.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import asyncio
1414
import os
15-
from typing import AsyncGenerator, Mapping, Optional
15+
from typing import AsyncGenerator, AsyncIterable, Generator, Mapping, Optional
1616

1717
from pipecat.utils.tracing.service_decorators import traced_tts
1818

@@ -35,14 +35,12 @@
3535

3636
try:
3737
import riva.client
38-
38+
import riva.client.proto.riva_tts_pb2 as rtts
3939
except ModuleNotFoundError as e:
4040
logger.error(f"Exception: {e}")
4141
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
4242
raise Exception(f"Missing module: {e}")
4343

44-
NVIDIA_TTS_TIMEOUT_SECS = 5
45-
4644

4745
class NvidiaTTSService(TTSService):
4846
"""NVIDIA Riva text-to-speech service.
@@ -165,26 +163,30 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
165163
Frame: Audio frames containing the synthesized speech data.
166164
"""
167165

168-
def read_audio_responses(queue: asyncio.Queue):
169-
def add_response(r):
170-
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
171-
166+
def read_audio_responses() -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
167+
responses = self._service.synthesize_online(
168+
text,
169+
self._voice_id,
170+
self._language_code,
171+
sample_rate_hz=self.sample_rate,
172+
zero_shot_audio_prompt_file=None,
173+
zero_shot_quality=self._quality,
174+
custom_dictionary={},
175+
)
176+
return responses
177+
178+
def async_next(it):
172179
try:
173-
responses = self._service.synthesize_online(
174-
text,
175-
self._voice_id,
176-
self._language_code,
177-
sample_rate_hz=self.sample_rate,
178-
zero_shot_audio_prompt_file=None,
179-
zero_shot_quality=self._quality,
180-
custom_dictionary={},
181-
)
182-
for r in responses:
183-
add_response(r)
184-
add_response(None)
185-
except Exception as e:
186-
logger.error(f"{self} exception: {e}")
187-
add_response(None)
180+
return next(it)
181+
except StopIteration:
182+
return None
183+
184+
async def async_iterator(iterator) -> AsyncIterable[rtts.SynthesizeSpeechResponse]:
185+
while True:
186+
item = await asyncio.to_thread(async_next, iterator)
187+
if item is None:
188+
return
189+
yield item
188190

189191
try:
190192
assert self._service is not None, "TTS service not initialized"
@@ -195,20 +197,16 @@ def add_response(r):
195197

196198
logger.debug(f"{self}: Generating TTS [{text}]")
197199

198-
queue = asyncio.Queue()
199-
await asyncio.to_thread(read_audio_responses, queue)
200+
responses = await asyncio.to_thread(read_audio_responses)
200201

201-
# Wait for the thread to start.
202-
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
203-
while resp:
202+
async for resp in async_iterator(responses):
204203
await self.stop_ttfb_metrics()
205204
frame = TTSAudioRawFrame(
206205
audio=resp.audio,
207206
sample_rate=self.sample_rate,
208207
num_channels=1,
209208
)
210209
yield frame
211-
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
212210

213211
await self.start_tts_usage_metrics(text)
214212
yield TTSStoppedFrame()

0 commit comments

Comments
 (0)