Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/3287.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
1 change: 1 addition & 0 deletions changelog/3287.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
130 changes: 96 additions & 34 deletions src/pipecat/services/asyncai/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import base64
import json
import uuid
from typing import AsyncGenerator, Optional

import aiohttp
Expand All @@ -27,7 +28,7 @@
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
from pipecat.services.tts_service import AudioContextTTSService, TTSService
from pipecat.transcriptions.language import Language, resolve_language
from pipecat.utils.tracing.service_decorators import traced_tts

Expand Down Expand Up @@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)


class AsyncAITTSService(InterruptibleTTSService):
class AsyncAITTSService(AudioContextTTSService):
"""Async TTS service with WebSocket streaming.

Provides text-to-speech using Async's streaming WebSocket API.
Expand Down Expand Up @@ -148,6 +149,7 @@ def __init__(
self._receive_task = None
self._keepalive_task = None
self._started = False
self._context_id = None

def can_generate_metrics(self) -> bool:
"""Check if this service can generate processing metrics.
Expand All @@ -168,8 +170,8 @@ def language_to_service_language(self, language: Language) -> Optional[str]:
"""
return language_to_async_language(language)

def _build_msg(self, text: str = "", force: bool = False) -> str:
msg = {"transcript": text, "force": force}
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
    """Serialize one outgoing WebSocket message as JSON.

    Args:
        text: Value placed in the message's ``transcript`` field.
        context_id: Value placed in the message's ``context_id`` field.
        force: Value placed in the message's ``force`` field.

    Returns:
        The JSON-encoded message string.
    """
    return json.dumps({"transcript": text, "context_id": context_id, "force": force})

async def start(self, frame: StartFrame):
Expand Down Expand Up @@ -253,11 +255,16 @@ async def _disconnect_websocket(self):

if self._websocket:
logger.debug("Disconnecting from Async")
# Close all contexts and the socket
if self._context_id:
await self._websocket.send(json.dumps({"terminate": True}))
await self._websocket.close()
logger.debug("Disconnected from Async")
except Exception as e:
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
finally:
self._websocket = None
self._context_id = None
self._started = False
await self._call_event_handler("on_disconnected")

Expand All @@ -268,10 +275,10 @@ def _get_websocket(self):

async def flush_audio(self):
"""Flush any pending audio.

Builds a message with a single-space transcript and ``force=True`` and sends
it over the WebSocket. Returns without sending when the guard below fails.
"""
# NOTE(review): this listing is a flattened diff view. The two consecutive
# guards below appear to be the pre-change and post-change versions of the
# same line (the second also requires a context id) — confirm against the
# real file before treating both as live code.
if not self._websocket:
if not self._context_id or not self._websocket:
return
logger.trace(f"{self}: flushing audio")
# NOTE(review): likewise, the two assignments below look like the old and new
# forms of one statement; the second one also passes the current context id.
msg = self._build_msg(text=" ", force=True)
msg = self._build_msg(text=" ", context_id=self._context_id, force=True)
await self._websocket.send(msg)

async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
Expand All @@ -291,35 +298,75 @@ async def _receive_messages(self):
if not msg:
continue

elif msg.get("audio"):
received_ctx_id = msg.get("context_id")
# Handle final messages first, regardless of context availability
# At the moment, this message is received AFTER the close_context message is
# sent, so it doesn't serve any functional purpose. For now, we'll just log it.
if msg.get("final") is True:
logger.trace(f"Received final message for context {received_ctx_id}")
continue

# Check if this message belongs to the current context.
if not self.audio_context_available(received_ctx_id):
if self._context_id == received_ctx_id:
logger.debug(
f"Received a delayed message, recreating the context: {self._context_id}"
)
await self.create_audio_context(self._context_id)
else:
# This can happen if a message is received _after_ we have closed a context
# due to user interruption but _before_ the `final` message for the context
# is received.
logger.debug(f"Ignoring message from unavailable context: {received_ctx_id}")
continue

if msg.get("audio"):
await self.stop_ttfb_metrics()
frame = TTSAudioRawFrame(
audio=base64.b64decode(msg["audio"]),
sample_rate=self.sample_rate,
num_channels=1,
)
await self.push_frame(frame)
elif msg.get("error_code"):
await self.push_frame(TTSStoppedFrame())
await self.stop_all_metrics()
await self.push_error(error_msg=f"Error: {msg['message']}")
else:
await self.push_error(error_msg=f"Unknown message type: {msg}")
audio = base64.b64decode(msg["audio"])
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
await self.append_to_audio_context(received_ctx_id, frame)

async def _keepalive_task_handler(self):
"""Send periodic keepalive messages to maintain WebSocket connection.

Loops forever: sleeps ``KEEPALIVE_SLEEP`` seconds, then, if the WebSocket is
open, sends a JSON message whose transcript is a single space. The message
carries the current context id when one is set. The loop ends when the send
raises ``websockets.ConnectionClosed``.
"""
# NOTE(review): this listing is a flattened diff view. The two assignments
# below appear to be the pre-change and post-change values of the interval
# (in seconds, as passed to asyncio.sleep) — confirm against the real file.
KEEPALIVE_SLEEP = 3
KEEPALIVE_SLEEP = 10
while True:
await asyncio.sleep(KEEPALIVE_SLEEP)
try:
if self._websocket and self._websocket.state is State.OPEN:
# NOTE(review): the next two lines look like the superseded, context-free
# version of the message construction that the if/else below replaces.
keepalive_message = {"transcript": " "}
logger.trace("Sending keepalive message")
if self._context_id:
# Tag the keepalive with the active context id.
keepalive_message = {
"transcript": " ",
"context_id": self._context_id,
}
logger.trace("Sending keepalive message")
else:
# It's possible to have a user interruption which clears the context
# without generating a new TTS response. In this case, we'll just send
# an empty message to keep the connection alive.
keepalive_message = {"transcript": " "}
logger.trace("Sending keepalive without context")
await self._websocket.send(json.dumps(keepalive_message))
except websockets.ConnectionClosed as e:
# The connection is gone: log it and stop the keepalive loop.
logger.warning(f"{self} keepalive error: {e}")
break

async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
    """Close the active context when an interruption arrives.

    Runs the parent class's interruption handling first. If a context id and a
    WebSocket are both present, a ``close_context`` message for that context is
    sent over the existing socket; the socket itself is left open. A failure to
    send is logged rather than raised. The stored context id and the started
    flag are reset in every case.

    Args:
        frame: The interruption frame being handled.
        direction: The direction associated with the frame.
    """
    await super()._handle_interruption(frame, direction)

    websocket = self._websocket
    context_id = self._context_id
    if context_id and websocket:
        try:
            close_request = json.dumps(
                {"context_id": context_id, "close_context": True, "transcript": ""}
            )
            await websocket.send(close_request)
        except Exception as e:
            # Best effort: local state is still reset below.
            logger.error(f"Error closing context on interruption: {e}")

    self._context_id = None
    self._started = False

@traced_tts
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Async API websocket endpoint.
Expand All @@ -336,21 +383,29 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
if not self._websocket or self._websocket.state is State.CLOSED:
await self._connect()

if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True

msg = self._build_msg(text=text, force=True)

try:
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
if not self._started:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
self._started = True

if not self._context_id:
self._context_id = str(uuid.uuid4())
if not self.audio_context_available(self._context_id):
await self.create_audio_context(self._context_id)

msg = self._build_msg(text=text, force=True, context_id=self._context_id)
await self._get_websocket().send(msg)
await self.start_tts_usage_metrics(text)
else:
if self._websocket and self._context_id:
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
await self._get_websocket().send(msg)

except Exception as e:
yield ErrorFrame(error=f"Unknown error occurred: {e}")
yield TTSStoppedFrame()
await self._disconnect()
await self._connect()
self._started = False
return
yield None
except Exception as e:
Expand Down Expand Up @@ -490,7 +545,14 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
await self.push_error(error_msg=f"Async API error: {error_text}")
raise Exception(f"Async API returned status {response.status}: {error_text}")

audio_data = await response.read()
# Read streaming bytes; stop TTFB on the *first* received chunk
buffer = bytearray()
async for chunk in response.content.iter_chunked(64 * 1024):
if not chunk:
continue
await self.stop_ttfb_metrics()
buffer.extend(chunk)
audio_data = bytes(buffer)

await self.start_tts_usage_metrics(text)

Expand Down