99import asyncio
1010import base64
1111import json
12+ import uuid
1213from typing import AsyncGenerator , Optional
1314
1415import aiohttp
2728 TTSStoppedFrame ,
2829)
2930from pipecat .processors .frame_processor import FrameDirection
30- from pipecat .services .tts_service import InterruptibleTTSService , TTSService
31+ from pipecat .services .tts_service import AudioContextTTSService , TTSService
3132from pipecat .transcriptions .language import Language , resolve_language
3233from pipecat .utils .tracing .service_decorators import traced_tts
3334
@@ -72,7 +73,7 @@ def language_to_async_language(language: Language) -> Optional[str]:
7273 return resolve_language (language , LANGUAGE_MAP , use_base_code = True )
7374
7475
75- class AsyncAITTSService (InterruptibleTTSService ):
76+ class AsyncAITTSService (AudioContextTTSService ):
7677 """Async TTS service with WebSocket streaming.
7778
7879 Provides text-to-speech using Async's streaming WebSocket API.
@@ -148,6 +149,7 @@ def __init__(
148149 self ._receive_task = None
149150 self ._keepalive_task = None
150151 self ._started = False
152+ self ._context_id = None
151153
152154 def can_generate_metrics (self ) -> bool :
153155 """Check if this service can generate processing metrics.
@@ -168,8 +170,8 @@ def language_to_service_language(self, language: Language) -> Optional[str]:
168170 """
169171 return language_to_async_language (language )
170172
171- def _build_msg (self , text : str = "" , force : bool = False ) -> str :
172- msg = {"transcript" : text , "force" : force }
173+ def _build_msg (self , text : str = "" , context_id : str = "" , force : bool = False ) -> str :
174+ msg = {"transcript" : text , "context_id" : context_id , " force" : force }
173175 return json .dumps (msg )
174176
175177 async def start (self , frame : StartFrame ):
@@ -253,11 +255,16 @@ async def _disconnect_websocket(self):
253255
254256 if self ._websocket :
255257 logger .debug ("Disconnecting from Async" )
258+ # Close all contexts and the socket
259+ if self ._context_id :
260+ await self ._websocket .send (json .dumps ({"terminate" : True }))
256261 await self ._websocket .close ()
262+ logger .debug ("Disconnected from Async" )
257263 except Exception as e :
258264 await self .push_error (error_msg = f"Unknown error occurred: { e } " , exception = e )
259265 finally :
260266 self ._websocket = None
267+ self ._context_id = None
261268 self ._started = False
262269 await self ._call_event_handler ("on_disconnected" )
263270
@@ -268,10 +275,10 @@ def _get_websocket(self):
268275
269276 async def flush_audio (self ):
270277 """Flush any pending audio."""
271- if not self ._websocket :
278+ if not self ._context_id or not self . _websocket :
272279 return
273280 logger .trace (f"{ self } : flushing audio" )
274- msg = self ._build_msg (text = " " , force = True )
281+ msg = self ._build_msg (text = " " , context_id = self . _context_id , force = True )
275282 await self ._websocket .send (msg )
276283
277284 async def push_frame (self , frame : Frame , direction : FrameDirection = FrameDirection .DOWNSTREAM ):
@@ -291,35 +298,75 @@ async def _receive_messages(self):
291298 if not msg :
292299 continue
293300
294- elif msg .get ("audio" ):
301+ received_ctx_id = msg .get ("context_id" )
302+ # Handle final messages first, regardless of context availability
303+ # At the moment, this message is received AFTER the close_context message is
304+ # sent, so it doesn't serve any functional purpose. For now, we'll just log it.
305+ if msg .get ("final" ) is True :
306+ logger .trace (f"Received final message for context { received_ctx_id } " )
307+ continue
308+
309+ # Check if this message belongs to the current context.
310+ if not self .audio_context_available (received_ctx_id ):
311+ if self ._context_id == received_ctx_id :
312+ logger .debug (
313+ f"Received a delayed message, recreating the context: { self ._context_id } "
314+ )
315+ await self .create_audio_context (self ._context_id )
316+ else :
317+ # This can happen if a message is received _after_ we have closed a context
318+ # due to user interruption but _before_ the `isFinal` message for the context
319+ # is received.
320+ logger .debug (f"Ignoring message from unavailable context: { received_ctx_id } " )
321+ continue
322+
323+ if msg .get ("audio" ):
295324 await self .stop_ttfb_metrics ()
296- frame = TTSAudioRawFrame (
297- audio = base64 .b64decode (msg ["audio" ]),
298- sample_rate = self .sample_rate ,
299- num_channels = 1 ,
300- )
301- await self .push_frame (frame )
302- elif msg .get ("error_code" ):
303- await self .push_frame (TTSStoppedFrame ())
304- await self .stop_all_metrics ()
305- await self .push_error (error_msg = f"Error: { msg ['message' ]} " )
306- else :
307- await self .push_error (error_msg = f"Unknown message type: { msg } " )
325+ audio = base64 .b64decode (msg ["audio" ])
326+ frame = TTSAudioRawFrame (audio , self .sample_rate , 1 )
327+ await self .append_to_audio_context (received_ctx_id , frame )
308328
309329 async def _keepalive_task_handler (self ):
310330 """Send periodic keepalive messages to maintain WebSocket connection."""
311- KEEPALIVE_SLEEP = 3
331+ KEEPALIVE_SLEEP = 10
312332 while True :
313333 await asyncio .sleep (KEEPALIVE_SLEEP )
314334 try :
315335 if self ._websocket and self ._websocket .state is State .OPEN :
316- keepalive_message = {"transcript" : " " }
317- logger .trace ("Sending keepalive message" )
336+ if self ._context_id :
337+ keepalive_message = {
338+ "transcript" : " " ,
339+ "context_id" : self ._context_id ,
340+ }
341+ logger .trace ("Sending keepalive message" )
342+ else :
343+ # It's possible to have a user interruption which clears the context
344+ # without generating a new TTS response. In this case, we'll just send
345+ # an empty message to keep the connection alive.
346+ keepalive_message = {"transcript" : " " }
347+ logger .trace ("Sending keepalive without context" )
318348 await self ._websocket .send (json .dumps (keepalive_message ))
319349 except websockets .ConnectionClosed as e :
320350 logger .warning (f"{ self } keepalive error: { e } " )
321351 break
322352
353+ async def _handle_interruption (self , frame : InterruptionFrame , direction : FrameDirection ):
354+ """Handle interruption by closing the current context."""
355+ await super ()._handle_interruption (frame , direction )
356+
357+ # Close the current context when interrupted without closing the websocket
358+ if self ._context_id and self ._websocket :
359+ try :
360+ await self ._websocket .send (
361+ json .dumps (
362+ {"context_id" : self ._context_id , "close_context" : True , "transcript" : "" }
363+ )
364+ )
365+ except Exception as e :
366+ logger .error (f"Error closing context on interruption: { e } " )
367+ self ._context_id = None
368+ self ._started = False
369+
323370 @traced_tts
324371 async def run_tts (self , text : str ) -> AsyncGenerator [Frame , None ]:
325372 """Generate speech from text using Async API websocket endpoint.
@@ -336,21 +383,29 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
336383 if not self ._websocket or self ._websocket .state is State .CLOSED :
337384 await self ._connect ()
338385
339- if not self ._started :
340- await self .start_ttfb_metrics ()
341- yield TTSStartedFrame ()
342- self ._started = True
343-
344- msg = self ._build_msg (text = text , force = True )
345-
346386 try :
347- await self ._get_websocket ().send (msg )
348- await self .start_tts_usage_metrics (text )
387+ if not self ._started :
388+ await self .start_ttfb_metrics ()
389+ yield TTSStartedFrame ()
390+ self ._started = True
391+
392+ if not self ._context_id :
393+ self ._context_id = str (uuid .uuid4 ())
394+ if not self .audio_context_available (self ._context_id ):
395+ await self .create_audio_context (self ._context_id )
396+
397+ msg = self ._build_msg (text = text , force = True , context_id = self ._context_id )
398+ await self ._get_websocket ().send (msg )
399+ await self .start_tts_usage_metrics (text )
400+ else :
401+ if self ._websocket and self ._context_id :
402+ msg = self ._build_msg (text = text , force = True , context_id = self ._context_id )
403+ await self ._get_websocket ().send (msg )
404+
349405 except Exception as e :
350406 yield ErrorFrame (error = f"Unknown error occurred: { e } " )
351407 yield TTSStoppedFrame ()
352- await self ._disconnect ()
353- await self ._connect ()
408+ self ._started = False
354409 return
355410 yield None
356411 except Exception as e :
@@ -490,7 +545,14 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
490545 await self .push_error (error_msg = f"Async API error: { error_text } " )
491546 raise Exception (f"Async API returned status { response .status } : { error_text } " )
492547
493- audio_data = await response .read ()
548+ # Read streaming bytes; stop TTFB on the *first* received chunk
549+ buffer = bytearray ()
550+ async for chunk in response .content .iter_chunked (64 * 1024 ):
551+ if not chunk :
552+ continue
553+ await self .stop_ttfb_metrics ()
554+ buffer .extend (chunk )
555+ audio_data = bytes (buffer )
494556
495557 await self .start_tts_usage_metrics (text )
496558
0 commit comments