Skip to content

Commit a9bfb09

Browse files
authored
Merge pull request pipecat-ai#3287 from ashotbagh/feature/asyncai-multicontext-wss
Fix TTFB metric and add multi-context WebSocket support for Async TTS
2 parents 86ed485 + c4ae402 commit a9bfb09

3 files changed

Lines changed: 98 additions & 34 deletions

File tree

changelog/3287.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.

changelog/3287.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.

src/pipecat/services/asyncai/tts.py

Lines changed: 96 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import asyncio
1010
import base64
1111
import json
12+
import uuid
1213
from typing import AsyncGenerator, Optional
1314

1415
import aiohttp
@@ -27,7 +28,7 @@
2728
TTSStoppedFrame,
2829
)
2930
from pipecat.processors.frame_processor import FrameDirection
30-
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
31+
from pipecat.services.tts_service import AudioContextTTSService, TTSService
3132
from pipecat.transcriptions.language import Language, resolve_language
3233
from pipecat.utils.tracing.service_decorators import traced_tts
3334

@@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
7273
return resolve_language(language, LANGUAGE_MAP, use_base_code=True)
7374

7475

75-
class AsyncAITTSService(InterruptibleTTSService):
76+
class AsyncAITTSService(AudioContextTTSService):
7677
"""Async TTS service with WebSocket streaming.
7778
7879
Provides text-to-speech using Async's streaming WebSocket API.
@@ -148,6 +149,7 @@ def __init__(
148149
self._receive_task = None
149150
self._keepalive_task = None
150151
self._started = False
152+
self._context_id = None
151153

152154
def can_generate_metrics(self) -> bool:
153155
"""Check if this service can generate processing metrics.
@@ -168,8 +170,8 @@ def language_to_service_language(self, language: Language) -> Optional[str]:
168170
"""
169171
return language_to_async_language(language)
170172

171-
def _build_msg(self, text: str = "", force: bool = False) -> str:
172-
msg = {"transcript": text, "force": force}
173+
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
174+
msg = {"transcript": text, "context_id": context_id, "force": force}
173175
return json.dumps(msg)
174176

175177
async def start(self, frame: StartFrame):
@@ -253,11 +255,16 @@ async def _disconnect_websocket(self):
253255

254256
if self._websocket:
255257
logger.debug("Disconnecting from Async")
258+
# Close all contexts and the socket
259+
if self._context_id:
260+
await self._websocket.send(json.dumps({"terminate": True}))
256261
await self._websocket.close()
262+
logger.debug("Disconnected from Async")
257263
except Exception as e:
258264
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
259265
finally:
260266
self._websocket = None
267+
self._context_id = None
261268
self._started = False
262269
await self._call_event_handler("on_disconnected")
263270

@@ -268,10 +275,10 @@ def _get_websocket(self):
268275

269276
async def flush_audio(self):
270277
"""Flush any pending audio."""
271-
if not self._websocket:
278+
if not self._context_id or not self._websocket:
272279
return
273280
logger.trace(f"{self}: flushing audio")
274-
msg = self._build_msg(text=" ", force=True)
281+
msg = self._build_msg(text=" ", context_id=self._context_id, force=True)
275282
await self._websocket.send(msg)
276283

277284
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
@@ -291,35 +298,75 @@ async def _receive_messages(self):
291298
if not msg:
292299
continue
293300

294-
elif msg.get("audio"):
301+
received_ctx_id = msg.get("context_id")
302+
# Handle final messages first, regardless of context availability
303+
# At the moment, this message is received AFTER the close_context message is
304+
# sent, so it doesn't serve any functional purpose. For now, we'll just log it.
305+
if msg.get("final") is True:
306+
logger.trace(f"Received final message for context {received_ctx_id}")
307+
continue
308+
309+
# Check if this message belongs to the current context.
310+
if not self.audio_context_available(received_ctx_id):
311+
if self._context_id == received_ctx_id:
312+
logger.debug(
313+
f"Received a delayed message, recreating the context: {self._context_id}"
314+
)
315+
await self.create_audio_context(self._context_id)
316+
else:
317+
# This can happen if a message is received _after_ we have closed a context
318+
# due to user interruption but _before_ the `isFinal` message for the context
319+
# is received.
320+
logger.debug(f"Ignoring message from unavailable context: {received_ctx_id}")
321+
continue
322+
323+
if msg.get("audio"):
295324
await self.stop_ttfb_metrics()
296-
frame = TTSAudioRawFrame(
297-
audio=base64.b64decode(msg["audio"]),
298-
sample_rate=self.sample_rate,
299-
num_channels=1,
300-
)
301-
await self.push_frame(frame)
302-
elif msg.get("error_code"):
303-
await self.push_frame(TTSStoppedFrame())
304-
await self.stop_all_metrics()
305-
await self.push_error(error_msg=f"Error: {msg['message']}")
306-
else:
307-
await self.push_error(error_msg=f"Unknown message type: {msg}")
325+
audio = base64.b64decode(msg["audio"])
326+
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
327+
await self.append_to_audio_context(received_ctx_id, frame)
308328

309329
async def _keepalive_task_handler(self):
310330
"""Send periodic keepalive messages to maintain WebSocket connection."""
311-
KEEPALIVE_SLEEP = 3
331+
KEEPALIVE_SLEEP = 10
312332
while True:
313333
await asyncio.sleep(KEEPALIVE_SLEEP)
314334
try:
315335
if self._websocket and self._websocket.state is State.OPEN:
316-
keepalive_message = {"transcript": " "}
317-
logger.trace("Sending keepalive message")
336+
if self._context_id:
337+
keepalive_message = {
338+
"transcript": " ",
339+
"context_id": self._context_id,
340+
}
341+
logger.trace("Sending keepalive message")
342+
else:
343+
# It's possible to have a user interruption which clears the context
344+
# without generating a new TTS response. In this case, we'll just send
345+
# an empty message to keep the connection alive.
346+
keepalive_message = {"transcript": " "}
347+
logger.trace("Sending keepalive without context")
318348
await self._websocket.send(json.dumps(keepalive_message))
319349
except websockets.ConnectionClosed as e:
320350
logger.warning(f"{self} keepalive error: {e}")
321351
break
322352

353+
async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
354+
"""Handle interruption by closing the current context."""
355+
await super()._handle_interruption(frame, direction)
356+
357+
# Close the current context when interrupted without closing the websocket
358+
if self._context_id and self._websocket:
359+
try:
360+
await self._websocket.send(
361+
json.dumps(
362+
{"context_id": self._context_id, "close_context": True, "transcript": ""}
363+
)
364+
)
365+
except Exception as e:
366+
logger.error(f"Error closing context on interruption: {e}")
367+
self._context_id = None
368+
self._started = False
369+
323370
@traced_tts
324371
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
325372
"""Generate speech from text using Async API websocket endpoint.
@@ -336,21 +383,29 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
336383
if not self._websocket or self._websocket.state is State.CLOSED:
337384
await self._connect()
338385

339-
if not self._started:
340-
await self.start_ttfb_metrics()
341-
yield TTSStartedFrame()
342-
self._started = True
343-
344-
msg = self._build_msg(text=text, force=True)
345-
346386
try:
347-
await self._get_websocket().send(msg)
348-
await self.start_tts_usage_metrics(text)
387+
if not self._started:
388+
await self.start_ttfb_metrics()
389+
yield TTSStartedFrame()
390+
self._started = True
391+
392+
if not self._context_id:
393+
self._context_id = str(uuid.uuid4())
394+
if not self.audio_context_available(self._context_id):
395+
await self.create_audio_context(self._context_id)
396+
397+
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
398+
await self._get_websocket().send(msg)
399+
await self.start_tts_usage_metrics(text)
400+
else:
401+
if self._websocket and self._context_id:
402+
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
403+
await self._get_websocket().send(msg)
404+
349405
except Exception as e:
350406
yield ErrorFrame(error=f"Unknown error occurred: {e}")
351407
yield TTSStoppedFrame()
352-
await self._disconnect()
353-
await self._connect()
408+
self._started = False
354409
return
355410
yield None
356411
except Exception as e:
@@ -490,7 +545,14 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
490545
await self.push_error(error_msg=f"Async API error: {error_text}")
491546
raise Exception(f"Async API returned status {response.status}: {error_text}")
492547

493-
audio_data = await response.read()
548+
# Read streaming bytes; stop TTFB on the *first* received chunk
549+
buffer = bytearray()
550+
async for chunk in response.content.iter_chunked(64 * 1024):
551+
if not chunk:
552+
continue
553+
await self.stop_ttfb_metrics()
554+
buffer.extend(chunk)
555+
audio_data = bytes(buffer)
494556

495557
await self.start_tts_usage_metrics(text)
496558

0 commit comments

Comments
 (0)