From 9cdbc56be3b671bda192a4f5529e95eeafb484c4 Mon Sep 17 00:00:00 2001 From: Ashot Date: Tue, 23 Dec 2025 16:35:45 +0400 Subject: [PATCH 1/4] Fix TTFB metric and add multi-context WebSocket support for Async TTS --- CHANGELOG.md | 9 + src/pipecat/services/asyncai/tts.py | 299 ++++++++++++++++++++++------ 2 files changed, 247 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d583b4e14..63ef32ed1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed + +- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management. + +### Fixed + +- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`. ## [0.0.99] - 2026-01-13 diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 3033692059..c49b95153d 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -9,8 +9,9 @@ import asyncio import base64 import json -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, Optional, Dict +import uuid import aiohttp from loguru import logger from pydantic import BaseModel @@ -27,7 +28,7 @@ TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import InterruptibleTTSService, TTSService +from pipecat.services.tts_service import WebsocketTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=True) -class AsyncAITTSService(InterruptibleTTSService): +class AsyncAITTSService(WebsocketTTSService): """Async TTS service with WebSocket streaming. Provides text-to-speech using Async's streaming WebSocket API. @@ -126,6 +127,10 @@ def __init__( **kwargs, ) + self._contexts: Dict[str, asyncio.Queue] = {} + self._audio_context_task = None + self._context_id = None + params = params or AsyncAITTSService.InputParams() self._api_key = api_key @@ -148,30 +153,56 @@ def __init__( self._receive_task = None self._keepalive_task = None self._started = False + + async def create_audio_context(self, context_id: str): + """Create a new audio context for grouping related audio. - def can_generate_metrics(self) -> bool: - """Check if this service can generate processing metrics. - - Returns: - True, as Async service supports metrics generation. + Args: + context_id: Unique identifier for the audio context. """ - return True + await self._contexts_queue.put(context_id) + self._contexts[context_id] = asyncio.Queue() + logger.trace(f"{self} created audio context {context_id}") - def language_to_service_language(self, language: Language) -> Optional[str]: - """Convert a Language enum to Async language format. + async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame): + """Append audio to an existing context. Args: - language: The language to convert. + context_id: The context to append audio to. + frame: The audio frame to append. + """ + if self.audio_context_available(context_id): + logger.trace(f"{self} appending audio {frame} to audio context {context_id}") + await self._contexts[context_id].put(frame) + else: + logger.warning(f"{self} unable to append audio to context {context_id}") - Returns: - The Async-specific language code, or None if not supported. + async def remove_audio_context(self, context_id: str): + """Remove an existing audio context. + + Args: + context_id: The context to remove. """ - return language_to_async_language(language) + if self.audio_context_available(context_id): + # We just mark the audio context for deletion by appending + # None. Once we reach None while handling audio we know we can + # safely remove the context. + logger.trace(f"{self} marking audio context {context_id} for deletion") + await self._contexts[context_id].put(None) + else: + logger.warning(f"{self} unable to remove context {context_id}") + + def audio_context_available(self, context_id: str) -> bool: + """Check whether the given audio context is registered. - def _build_msg(self, text: str = "", force: bool = False) -> str: - msg = {"transcript": text, "force": force} - return json.dumps(msg) + Args: + context_id: The context ID to check. + Returns: + True if the context exists and is available. + """ + return context_id in self._contexts + async def start(self, frame: StartFrame): """Start the Async TTS service. @@ -179,6 +210,7 @@ async def start(self, frame: StartFrame): frame: The start frame containing initialization parameters. """ await super().start(frame) + self._create_audio_context_task() self._settings["output_format"]["sample_rate"] = self.sample_rate await self._connect() @@ -189,6 +221,12 @@ async def stop(self, frame: EndFrame): frame: The end frame. """ await super().stop(frame) + if self._audio_context_task: + # Indicate no more audio contexts are available. this will end the + # task cleanly after all contexts have been processed. + await self._contexts_queue.put(None) + await self._audio_context_task + self._audio_context_task = None await self._disconnect() async def cancel(self, frame: CancelFrame): @@ -198,7 +236,88 @@ async def cancel(self, frame: CancelFrame): frame: The cancel frame. """ await super().cancel(frame) - await self._disconnect() + await self._stop_audio_context_task() + await self._disconnect() + + async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): + await super()._handle_interruption(frame, direction) + await self._stop_audio_context_task() + self._create_audio_context_task() + + def _create_audio_context_task(self): + if not self._audio_context_task: + self._contexts_queue = asyncio.Queue() + self._contexts: Dict[str, asyncio.Queue] = {} + self._audio_context_task = self.create_task(self._audio_context_task_handler()) + + async def _stop_audio_context_task(self): + if self._audio_context_task: + await self.cancel_task(self._audio_context_task) + self._audio_context_task = None + + async def _audio_context_task_handler(self): + """In this task we process audio contexts in order.""" + running = True + while running: + context_id = await self._contexts_queue.get() + + if context_id: + # Process the audio context until the context doesn't have more + # audio available (i.e. we find None). + await self._handle_audio_context(context_id) + + # We just finished processing the context, so we can safely remove it. + del self._contexts[context_id] + + # Append some silence between sentences. + silence = b"\x00" * self.sample_rate + frame = TTSAudioRawFrame( + audio=silence, sample_rate=self.sample_rate, num_channels=1 + ) + await self.push_frame(frame) + else: + running = False + + self._contexts_queue.task_done() + + async def _handle_audio_context(self, context_id: str): + # If we don't receive any audio during this time, we consider the context finished. + AUDIO_CONTEXT_TIMEOUT = 3.0 + queue = self._contexts[context_id] + running = True + while running: + try: + frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT) + if frame: + await self.push_frame(frame) + running = frame is not None + except asyncio.TimeoutError: + # We didn't get audio, so let's consider this context finished. + logger.trace(f"{self} time out on audio context {context_id}") + break + + def can_generate_metrics(self) -> bool: + """Check if this service can generate processing metrics. + + Returns: + True, as Async service supports metrics generation. + """ + return True + + def language_to_service_language(self, language: Language) -> Optional[str]: + """Convert a Language enum to Async language format. + + Args: + language: The language to convert. + + Returns: + The Async-specific language code, or None if not supported. + """ + return language_to_async_language(language) + + def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str: + msg = {"transcript": text, "context_id": context_id, "force": force} + return json.dumps(msg) async def _connect(self): await super()._connect() @@ -253,11 +372,16 @@ async def _disconnect_websocket(self): if self._websocket: logger.debug("Disconnecting from Async") + # Close all contexts and the socket + if self._context_id: + await self._websocket.send(json.dumps({"terminate": True})) await self._websocket.close() + logger.debug("Disconnected from Async") except Exception as e: - await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) + logger.error(f"{self} error closing websocket: {e}") finally: self._websocket = None + self._context_id = None self._started = False await self._call_event_handler("on_disconnected") @@ -268,10 +392,10 @@ def _get_websocket(self): async def flush_audio(self): """Flush any pending audio.""" - if not self._websocket: + if not self._context_id or not self._websocket: return logger.trace(f"{self}: flushing audio") - msg = self._build_msg(text=" ", force=True) + msg = self._build_msg(text=" ", context_id=self._context_id, force=True) await self._websocket.send(msg) async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM): @@ -291,35 +415,70 @@ async def _receive_messages(self): if not msg: continue - elif msg.get("audio"): + received_ctx_id = msg.get("context_id") + # Handle final messages first, regardless of context availability + # At the moment, this message is received AFTER the close_context message is + # sent, so it doesn't serve any functional purpose. For now, we'll just log it. + if msg.get("final") is True: + logger.trace(f"Received final message for context {received_ctx_id}") + continue + + # Check if this message belongs to the current context. + if not self.audio_context_available(received_ctx_id): + if self._context_id == received_ctx_id: + logger.debug( + f"Received a delayed message, recreating the context: {self._context_id}" + ) + await self.create_audio_context(self._context_id) + else: + # This can happen if a message is received _after_ we have closed a context + # due to user interruption but _before_ the `isFinal` message for the context + # is received. + logger.debug(f"Ignoring message from unavailable context: {received_ctx_id}") + continue + + if msg.get("audio"): await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=base64.b64decode(msg["audio"]), - sample_rate=self.sample_rate, - num_channels=1, - ) - await self.push_frame(frame) - elif msg.get("error_code"): - await self.push_frame(TTSStoppedFrame()) - await self.stop_all_metrics() - await self.push_error(error_msg=f"Error: {msg['message']}") - else: - await self.push_error(error_msg=f"Unknown message type: {msg}") + audio = base64.b64decode(msg["audio"]) + frame = TTSAudioRawFrame(audio, self.sample_rate, 1) + await self.append_to_audio_context(received_ctx_id, frame) async def _keepalive_task_handler(self): """Send periodic keepalive messages to maintain WebSocket connection.""" - KEEPALIVE_SLEEP = 3 + KEEPALIVE_SLEEP = 10 while True: await asyncio.sleep(KEEPALIVE_SLEEP) try: if self._websocket and self._websocket.state is State.OPEN: - keepalive_message = {"transcript": " "} - logger.trace("Sending keepalive message") + if self._context_id: + keepalive_message = {"transcript": " ", "context_id": self._context_id,} + logger.trace("Sending keepalive message") + else: + # It's possible to have a user interruption which clears the context + # without generating a new TTS response. In this case, we'll just send + # an empty message to keep the connection alive. + keepalive_message = {"transcript": " "} + logger.trace("Sending keepalive without context") await self._websocket.send(json.dumps(keepalive_message)) except websockets.ConnectionClosed as e: logger.warning(f"{self} keepalive error: {e}") break + async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): + """Handle interruption by closing the current context.""" + await super()._handle_interruption(frame, direction) + + # Close the current context when interrupted without closing the websocket + if self._context_id and self._websocket: + try: + await self._websocket.send( + json.dumps({"context_id": self._context_id, "close_context": True, "transcript": ""}) + ) + except Exception as e: + logger.error(f"Error closing context on interruption: {e}") + self._context_id = None + self._started = False + @traced_tts async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: """Generate speech from text using Async API websocket endpoint. @@ -336,26 +495,35 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - if not self._started: - await self.start_ttfb_metrics() - yield TTSStartedFrame() - self._started = True - - msg = self._build_msg(text=text, force=True) - - try: - await self._get_websocket().send(msg) - await self.start_tts_usage_metrics(text) + try: + if not self._started: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + self._started = True + + if not self._context_id: + self._context_id = str(uuid.uuid4()) + if not self.audio_context_available(self._context_id): + await self.create_audio_context(self._context_id) + + msg = self._build_msg(text=" ", context_id=self._context_id) + await self._get_websocket().send(msg) + msg = self._build_msg(text=text, force=True, context_id=self._context_id) + await self._get_websocket().send(msg) + await self.start_tts_usage_metrics(text) + else: + if self._websocket and self._context_id: + msg = self._build_msg(text=text, force=True, context_id=self._context_id) + await self._get_websocket().send(msg) + except Exception as e: - yield ErrorFrame(error=f"Unknown error occurred: {e}") + logger.error(f"{self} error sending message: {e}") yield TTSStoppedFrame() - await self._disconnect() - await self._connect() + self._started = False return yield None except Exception as e: - yield ErrorFrame(error=f"Unknown error occurred: {e}") - + logger.error(f"{self} exception: {e}") class AsyncAIHttpTTSService(TTSService): """HTTP-based Async TTS service. @@ -466,9 +634,9 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: """ logger.debug(f"{self}: Generating TTS [{text}]") + first_byte_seen = False try: voice_config = {"mode": "id", "id": self._voice_id} - await self.start_ttfb_metrics() payload = { "model_id": self._model_name, "transcript": text, @@ -476,7 +644,6 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: "output_format": self._settings["output_format"], "language": self._settings["language"], } - yield TTSStartedFrame() headers = { "version": self._api_version, "x-api-key": self._api_key, @@ -484,26 +651,36 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: } url = f"{self._base_url}/text_to_speech/streaming" + yield TTSStartedFrame() + await self.start_ttfb_metrics() async with self._session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() await self.push_error(error_msg=f"Async API error: {error_text}") raise Exception(f"Async API returned status {response.status}: {error_text}") - audio_data = await response.read() + # Read streaming bytes; stop TTFB on the *first* received chunk + buffer = bytearray() + async for chunk in response.content.iter_chunked(64 * 1024): + if not chunk: + continue + if not first_byte_seen: + first_byte_seen = True + await self.stop_ttfb_metrics() + await self.start_tts_usage_metrics(text) - await self.start_tts_usage_metrics(text) + buffer.extend(chunk) + audio_data = bytes(buffer) - frame = TTSAudioRawFrame( + yield TTSAudioRawFrame( audio=audio_data, sample_rate=self.sample_rate, num_channels=1, ) - yield frame - except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - await self.stop_ttfb_metrics() + if not first_byte_seen: + await self.stop_ttfb_metrics() yield TTSStoppedFrame() From 5ae592f38e37e275dfaef46d5701432aff1df1fb Mon Sep 17 00:00:00 2001 From: Ashot Date: Wed, 7 Jan 2026 15:55:35 +0400 Subject: [PATCH 2/4] Improve Async TTS interruption handling by using AudioContextTTSService class and add changelog fragments --- CHANGELOG.md | 9 --- changelog/3287.changed.md | 1 + changelog/3287.fixed.md | 1 + src/pipecat/services/asyncai/tts.py | 120 +--------------------------- 4 files changed, 5 insertions(+), 126 deletions(-) create mode 100644 changelog/3287.changed.md create mode 100644 changelog/3287.fixed.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 63ef32ed1a..3d583b4e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,15 +6,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] - -### Changed - -- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management. - -### Fixed - -- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`. ## [0.0.99] - 2026-01-13 diff --git a/changelog/3287.changed.md b/changelog/3287.changed.md new file mode 100644 index 0000000000..f0df829661 --- /dev/null +++ b/changelog/3287.changed.md @@ -0,0 +1 @@ +- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management. \ No newline at end of file diff --git a/changelog/3287.fixed.md b/changelog/3287.fixed.md new file mode 100644 index 0000000000..30ce0b13bc --- /dev/null +++ b/changelog/3287.fixed.md @@ -0,0 +1 @@ +- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`. \ No newline at end of file diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index c49b95153d..05ba186545 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -28,7 +28,7 @@ TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import WebsocketTTSService, TTSService +from pipecat.services.tts_service import AudioContextTTSService, WebsocketTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -73,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=True) -class AsyncAITTSService(WebsocketTTSService): +class AsyncAITTSService(AudioContextTTSService, WebsocketTTSService): """Async TTS service with WebSocket streaming. Provides text-to-speech using Async's streaming WebSocket API. @@ -154,55 +154,6 @@ def __init__( self._keepalive_task = None self._started = False - async def create_audio_context(self, context_id: str): - """Create a new audio context for grouping related audio. - - Args: - context_id: Unique identifier for the audio context. - """ - await self._contexts_queue.put(context_id) - self._contexts[context_id] = asyncio.Queue() - logger.trace(f"{self} created audio context {context_id}") - - async def append_to_audio_context(self, context_id: str, frame: TTSAudioRawFrame): - """Append audio to an existing context. - - Args: - context_id: The context to append audio to. - frame: The audio frame to append. - """ - if self.audio_context_available(context_id): - logger.trace(f"{self} appending audio {frame} to audio context {context_id}") - await self._contexts[context_id].put(frame) - else: - logger.warning(f"{self} unable to append audio to context {context_id}") - - async def remove_audio_context(self, context_id: str): - """Remove an existing audio context. - - Args: - context_id: The context to remove. - """ - if self.audio_context_available(context_id): - # We just mark the audio context for deletion by appending - # None. Once we reach None while handling audio we know we can - # safely remove the context. - logger.trace(f"{self} marking audio context {context_id} for deletion") - await self._contexts[context_id].put(None) - else: - logger.warning(f"{self} unable to remove context {context_id}") - - def audio_context_available(self, context_id: str) -> bool: - """Check whether the given audio context is registered. - - Args: - context_id: The context ID to check. - - Returns: - True if the context exists and is available. - """ - return context_id in self._contexts - async def start(self, frame: StartFrame): """Start the Async TTS service. @@ -210,7 +161,6 @@ async def start(self, frame: StartFrame): frame: The start frame containing initialization parameters. """ await super().start(frame) - self._create_audio_context_task() self._settings["output_format"]["sample_rate"] = self.sample_rate await self._connect() @@ -221,12 +171,6 @@ async def stop(self, frame: EndFrame): frame: The end frame. """ await super().stop(frame) - if self._audio_context_task: - # Indicate no more audio contexts are available. this will end the - # task cleanly after all contexts have been processed. - await self._contexts_queue.put(None) - await self._audio_context_task - self._audio_context_task = None await self._disconnect() async def cancel(self, frame: CancelFrame): @@ -236,65 +180,7 @@ async def cancel(self, frame: CancelFrame): frame: The cancel frame. """ await super().cancel(frame) - await self._stop_audio_context_task() - await self._disconnect() - - async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection): - await super()._handle_interruption(frame, direction) - await self._stop_audio_context_task() - self._create_audio_context_task() - - def _create_audio_context_task(self): - if not self._audio_context_task: - self._contexts_queue = asyncio.Queue() - self._contexts: Dict[str, asyncio.Queue] = {} - self._audio_context_task = self.create_task(self._audio_context_task_handler()) - - async def _stop_audio_context_task(self): - if self._audio_context_task: - await self.cancel_task(self._audio_context_task) - self._audio_context_task = None - - async def _audio_context_task_handler(self): - """In this task we process audio contexts in order.""" - running = True - while running: - context_id = await self._contexts_queue.get() - - if context_id: - # Process the audio context until the context doesn't have more - # audio available (i.e. we find None). - await self._handle_audio_context(context_id) - - # We just finished processing the context, so we can safely remove it. - del self._contexts[context_id] - - # Append some silence between sentences. - silence = b"\x00" * self.sample_rate - frame = TTSAudioRawFrame( - audio=silence, sample_rate=self.sample_rate, num_channels=1 - ) - await self.push_frame(frame) - else: - running = False - - self._contexts_queue.task_done() - - async def _handle_audio_context(self, context_id: str): - # If we don't receive any audio during this time, we consider the context finished. - AUDIO_CONTEXT_TIMEOUT = 3.0 - queue = self._contexts[context_id] - running = True - while running: - try: - frame = await asyncio.wait_for(queue.get(), timeout=AUDIO_CONTEXT_TIMEOUT) - if frame: - await self.push_frame(frame) - running = frame is not None - except asyncio.TimeoutError: - # We didn't get audio, so let's consider this context finished. - logger.trace(f"{self} time out on audio context {context_id}") - break + await self._disconnect() def can_generate_metrics(self) -> bool: """Check if this service can generate processing metrics. From 15067c678d00213aa2a2ee7dbb862dbd0e3c54a2 Mon Sep 17 00:00:00 2001 From: Ashot Date: Wed, 7 Jan 2026 21:42:30 +0400 Subject: [PATCH 3/4] adapt Async TTS to updated AudioContextTTSService --- src/pipecat/services/asyncai/tts.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 05ba186545..04a847955f 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -28,7 +28,7 @@ TTSStoppedFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.tts_service import AudioContextTTSService, WebsocketTTSService, TTSService +from pipecat.services.tts_service import AudioContextTTSService, TTSService from pipecat.transcriptions.language import Language, resolve_language from pipecat.utils.tracing.service_decorators import traced_tts @@ -73,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]: return resolve_language(language, LANGUAGE_MAP, use_base_code=True) -class AsyncAITTSService(AudioContextTTSService, WebsocketTTSService): +class AsyncAITTSService(AudioContextTTSService): """Async TTS service with WebSocket streaming. Provides text-to-speech using Async's streaming WebSocket API. @@ -153,7 +153,7 @@ def __init__( self._receive_task = None self._keepalive_task = None self._started = False - + async def start(self, frame: StartFrame): """Start the Async TTS service. @@ -337,7 +337,10 @@ async def _keepalive_task_handler(self): try: if self._websocket and self._websocket.state is State.OPEN: if self._context_id: - keepalive_message = {"transcript": " ", "context_id": self._context_id,} + keepalive_message = { + "transcript": " ", + "context_id": self._context_id, + } logger.trace("Sending keepalive message") else: # It's possible to have a user interruption which clears the context @@ -358,7 +361,9 @@ async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameD if self._context_id and self._websocket: try: await self._websocket.send( - json.dumps({"context_id": self._context_id, "close_context": True, "transcript": ""}) + json.dumps( + {"context_id": self._context_id, "close_context": True, "transcript": ""} + ) ) except Exception as e: logger.error(f"Error closing context on interruption: {e}") @@ -381,7 +386,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: if not self._websocket or self._websocket.state is State.CLOSED: await self._connect() - try: + try: if not self._started: await self.start_ttfb_metrics() yield TTSStartedFrame() @@ -401,7 +406,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: if self._websocket and self._context_id: msg = self._build_msg(text=text, force=True, context_id=self._context_id) await self._get_websocket().send(msg) - + except Exception as e: logger.error(f"{self} error sending message: {e}") yield TTSStoppedFrame() @@ -409,7 +414,8 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: return yield None except Exception as e: - logger.error(f"{self} exception: {e}") + logger.error(f"{self} exception: {e}") + class AsyncAIHttpTTSService(TTSService): """HTTP-based Async TTS service. From c4ae4025f3bc28d8a65e8ffc118fda7104500f77 Mon Sep 17 00:00:00 2001 From: Ashot Date: Wed, 14 Jan 2026 16:33:30 +0400 Subject: [PATCH 4/4] Adjustments of Async TTS for multicontext websocket support --- src/pipecat/services/asyncai/tts.py | 87 +++++++++++++---------------- 1 file changed, 40 insertions(+), 47 deletions(-) diff --git a/src/pipecat/services/asyncai/tts.py b/src/pipecat/services/asyncai/tts.py index 04a847955f..fbc760562f 100644 --- a/src/pipecat/services/asyncai/tts.py +++ b/src/pipecat/services/asyncai/tts.py @@ -9,9 +9,9 @@ import asyncio import base64 import json -from typing import AsyncGenerator, Optional, Dict - import uuid +from typing import AsyncGenerator, Optional + import aiohttp from loguru import logger from pydantic import BaseModel @@ -127,10 +127,6 @@ def __init__( **kwargs, ) - self._contexts: Dict[str, asyncio.Queue] = {} - self._audio_context_task = None - self._context_id = None - params = params or AsyncAITTSService.InputParams() self._api_key = api_key @@ -153,6 +149,30 @@ def __init__( self._receive_task = None self._keepalive_task = None self._started = False + self._context_id = None + + def can_generate_metrics(self) -> bool: + """Check if this service can generate processing metrics. + + Returns: + True, as Async service supports metrics generation. + """ + return True + + def language_to_service_language(self, language: Language) -> Optional[str]: + """Convert a Language enum to Async language format. + + Args: + language: The language to convert. + + Returns: + The Async-specific language code, or None if not supported. + """ + return language_to_async_language(language) + + def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str: + msg = {"transcript": text, "context_id": context_id, "force": force} + return json.dumps(msg) async def start(self, frame: StartFrame): """Start the Async TTS service. @@ -182,29 +202,6 @@ async def cancel(self, frame: CancelFrame): await super().cancel(frame) await self._disconnect() - def can_generate_metrics(self) -> bool: - """Check if this service can generate processing metrics. - - Returns: - True, as Async service supports metrics generation. - """ - return True - - def language_to_service_language(self, language: Language) -> Optional[str]: - """Convert a Language enum to Async language format. - - Args: - language: The language to convert. - - Returns: - The Async-specific language code, or None if not supported. - """ - return language_to_async_language(language) - - def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str: - msg = {"transcript": text, "context_id": context_id, "force": force} - return json.dumps(msg) - async def _connect(self): await super()._connect() @@ -264,7 +261,7 @@ async def _disconnect_websocket(self): await self._websocket.close() logger.debug("Disconnected from Async") except Exception as e: - logger.error(f"{self} error closing websocket: {e}") + await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: self._websocket = None self._context_id = None @@ -338,7 +335,7 @@ async def _keepalive_task_handler(self): if self._websocket and self._websocket.state is State.OPEN: if self._context_id: keepalive_message = { - "transcript": " ", + "transcript": " ", "context_id": self._context_id, } logger.trace("Sending keepalive message") @@ -397,24 +394,22 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: if not self.audio_context_available(self._context_id): await self.create_audio_context(self._context_id) - msg = self._build_msg(text=" ", context_id=self._context_id) - await self._get_websocket().send(msg) msg = self._build_msg(text=text, force=True, context_id=self._context_id) await self._get_websocket().send(msg) await self.start_tts_usage_metrics(text) else: if self._websocket and self._context_id: msg = self._build_msg(text=text, force=True, context_id=self._context_id) - await self._get_websocket().send(msg) + await self._get_websocket().send(msg) except Exception as e: - logger.error(f"{self} error sending message: {e}") + yield ErrorFrame(error=f"Unknown error occurred: {e}") yield TTSStoppedFrame() self._started = False return yield None except Exception as e: - logger.error(f"{self} exception: {e}") + yield ErrorFrame(error=f"Unknown error occurred: {e}") class AsyncAIHttpTTSService(TTSService): @@ -526,9 +521,9 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: """ logger.debug(f"{self}: Generating TTS [{text}]") - first_byte_seen = False try: voice_config = {"mode": "id", "id": self._voice_id} + await self.start_ttfb_metrics() payload = { "model_id": self._model_name, "transcript": text, @@ -536,6 +531,7 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: "output_format": self._settings["output_format"], "language": self._settings["language"], } + yield TTSStartedFrame() headers = { "version": self._api_version, "x-api-key": self._api_key, @@ -543,8 +539,6 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: } url = f"{self._base_url}/text_to_speech/streaming" - yield TTSStartedFrame() - await self.start_ttfb_metrics() async with self._session.post(url, json=payload, headers=headers) as response: if response.status != 200: error_text = await response.text() @@ -556,23 +550,22 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: async for chunk in response.content.iter_chunked(64 * 1024): if not chunk: continue - if not first_byte_seen: - first_byte_seen = True - await self.stop_ttfb_metrics() - await self.start_tts_usage_metrics(text) - + await self.stop_ttfb_metrics() buffer.extend(chunk) audio_data = bytes(buffer) - yield TTSAudioRawFrame( + await self.start_tts_usage_metrics(text) + + frame = TTSAudioRawFrame( audio=audio_data, sample_rate=self.sample_rate, num_channels=1, ) + yield frame + except Exception as e: await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e) finally: - if not first_byte_seen: - await self.stop_ttfb_metrics() + await self.stop_ttfb_metrics() yield TTSStoppedFrame()