Skip to content

Commit c4ae402

Browse files
author
Ashot
committed
Adjustments of Async TTS for multicontext websocket support
1 parent 15067c6 commit c4ae402

1 file changed

Lines changed: 40 additions & 47 deletions

File tree

  • src/pipecat/services/asyncai

src/pipecat/services/asyncai/tts.py

Lines changed: 40 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,9 @@
99
import asyncio
1010
import base64
1111
import json
12-
from typing import AsyncGenerator, Optional, Dict
13-
1412
import uuid
13+
from typing import AsyncGenerator, Optional
14+
1515
import aiohttp
1616
from loguru import logger
1717
from pydantic import BaseModel
@@ -127,10 +127,6 @@ def __init__(
127127
**kwargs,
128128
)
129129

130-
self._contexts: Dict[str, asyncio.Queue] = {}
131-
self._audio_context_task = None
132-
self._context_id = None
133-
134130
params = params or AsyncAITTSService.InputParams()
135131

136132
self._api_key = api_key
@@ -153,6 +149,30 @@ def __init__(
153149
self._receive_task = None
154150
self._keepalive_task = None
155151
self._started = False
152+
self._context_id = None
153+
154+
def can_generate_metrics(self) -> bool:
155+
"""Check if this service can generate processing metrics.
156+
157+
Returns:
158+
True, as Async service supports metrics generation.
159+
"""
160+
return True
161+
162+
def language_to_service_language(self, language: Language) -> Optional[str]:
163+
"""Convert a Language enum to Async language format.
164+
165+
Args:
166+
language: The language to convert.
167+
168+
Returns:
169+
The Async-specific language code, or None if not supported.
170+
"""
171+
return language_to_async_language(language)
172+
173+
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
174+
msg = {"transcript": text, "context_id": context_id, "force": force}
175+
return json.dumps(msg)
156176

157177
async def start(self, frame: StartFrame):
158178
"""Start the Async TTS service.
@@ -182,29 +202,6 @@ async def cancel(self, frame: CancelFrame):
182202
await super().cancel(frame)
183203
await self._disconnect()
184204

185-
def can_generate_metrics(self) -> bool:
186-
"""Check if this service can generate processing metrics.
187-
188-
Returns:
189-
True, as Async service supports metrics generation.
190-
"""
191-
return True
192-
193-
def language_to_service_language(self, language: Language) -> Optional[str]:
194-
"""Convert a Language enum to Async language format.
195-
196-
Args:
197-
language: The language to convert.
198-
199-
Returns:
200-
The Async-specific language code, or None if not supported.
201-
"""
202-
return language_to_async_language(language)
203-
204-
def _build_msg(self, text: str = "", context_id: str = "", force: bool = False) -> str:
205-
msg = {"transcript": text, "context_id": context_id, "force": force}
206-
return json.dumps(msg)
207-
208205
async def _connect(self):
209206
await super()._connect()
210207

@@ -264,7 +261,7 @@ async def _disconnect_websocket(self):
264261
await self._websocket.close()
265262
logger.debug("Disconnected from Async")
266263
except Exception as e:
267-
logger.error(f"{self} error closing websocket: {e}")
264+
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
268265
finally:
269266
self._websocket = None
270267
self._context_id = None
@@ -338,7 +335,7 @@ async def _keepalive_task_handler(self):
338335
if self._websocket and self._websocket.state is State.OPEN:
339336
if self._context_id:
340337
keepalive_message = {
341-
"transcript": " ",
338+
"transcript": " ",
342339
"context_id": self._context_id,
343340
}
344341
logger.trace("Sending keepalive message")
@@ -397,24 +394,22 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
397394
if not self.audio_context_available(self._context_id):
398395
await self.create_audio_context(self._context_id)
399396

400-
msg = self._build_msg(text=" ", context_id=self._context_id)
401-
await self._get_websocket().send(msg)
402397
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
403398
await self._get_websocket().send(msg)
404399
await self.start_tts_usage_metrics(text)
405400
else:
406401
if self._websocket and self._context_id:
407402
msg = self._build_msg(text=text, force=True, context_id=self._context_id)
408-
await self._get_websocket().send(msg)
403+
await self._get_websocket().send(msg)
409404

410405
except Exception as e:
411-
logger.error(f"{self} error sending message: {e}")
406+
yield ErrorFrame(error=f"Unknown error occurred: {e}")
412407
yield TTSStoppedFrame()
413408
self._started = False
414409
return
415410
yield None
416411
except Exception as e:
417-
logger.error(f"{self} exception: {e}")
412+
yield ErrorFrame(error=f"Unknown error occurred: {e}")
418413

419414

420415
class AsyncAIHttpTTSService(TTSService):
@@ -526,25 +521,24 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
526521
"""
527522
logger.debug(f"{self}: Generating TTS [{text}]")
528523

529-
first_byte_seen = False
530524
try:
531525
voice_config = {"mode": "id", "id": self._voice_id}
526+
await self.start_ttfb_metrics()
532527
payload = {
533528
"model_id": self._model_name,
534529
"transcript": text,
535530
"voice": voice_config,
536531
"output_format": self._settings["output_format"],
537532
"language": self._settings["language"],
538533
}
534+
yield TTSStartedFrame()
539535
headers = {
540536
"version": self._api_version,
541537
"x-api-key": self._api_key,
542538
"Content-Type": "application/json",
543539
}
544540
url = f"{self._base_url}/text_to_speech/streaming"
545541

546-
yield TTSStartedFrame()
547-
await self.start_ttfb_metrics()
548542
async with self._session.post(url, json=payload, headers=headers) as response:
549543
if response.status != 200:
550544
error_text = await response.text()
@@ -556,23 +550,22 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
556550
async for chunk in response.content.iter_chunked(64 * 1024):
557551
if not chunk:
558552
continue
559-
if not first_byte_seen:
560-
first_byte_seen = True
561-
await self.stop_ttfb_metrics()
562-
await self.start_tts_usage_metrics(text)
563-
553+
await self.stop_ttfb_metrics()
564554
buffer.extend(chunk)
565555
audio_data = bytes(buffer)
566556

567-
yield TTSAudioRawFrame(
557+
await self.start_tts_usage_metrics(text)
558+
559+
frame = TTSAudioRawFrame(
568560
audio=audio_data,
569561
sample_rate=self.sample_rate,
570562
num_channels=1,
571563
)
572564

565+
yield frame
566+
573567
except Exception as e:
574568
await self.push_error(error_msg=f"Unknown error occurred: {e}", exception=e)
575569
finally:
576-
if not first_byte_seen:
577-
await self.stop_ttfb_metrics()
570+
await self.stop_ttfb_metrics()
578571
yield TTSStoppedFrame()

0 commit comments

Comments
 (0)