66
77"""Piper TTS service implementation."""
88
9- from typing import AsyncGenerator , Optional
9+ import asyncio
10+ from pathlib import Path
11+ from typing import AsyncGenerator , AsyncIterator , Optional
1012
1113import aiohttp
1214from loguru import logger
2022from pipecat .services .tts_service import TTSService
2123from pipecat .utils .tracing .service_decorators import traced_tts
2224
25+ try :
26+ from piper import PiperVoice
27+ from piper .download_voices import download_voice
28+ except ModuleNotFoundError as e :
29+ logger .error (f"Exception: { e } " )
30+ logger .error ("In order to use Piper, you need to `pip install pipecat-ai[piper]`." )
31+ raise Exception (f"Missing module: { e } " )
32+
2333
24- # This assumes a running TTS service running: https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/API_HTTP.md
2534class PiperTTSService (TTSService ):
2635 """Piper TTS service implementation.
2736
37+ Provides local text-to-speech synthesis using Piper voice models. Automatically
38+ downloads voice models if not already present and resamples audio output to
39+ match the configured sample rate.
40+ """
41+
42+ def __init__ (
43+ self ,
44+ * ,
45+ voice_id : str ,
46+ download_dir : Optional [Path ] = None ,
47+ force_redownload : bool = False ,
48+ use_cuda : bool = False ,
49+ ** kwargs ,
50+ ):
51+ """Initialize the Piper TTS service.
52+
53+ Args:
54+ voice_id: Piper voice model identifier (e.g. `en_US-ryan-high`).
55+ download_dir: Directory for storing voice model files. Defaults to
56+ the current working directory.
57+ force_redownload: Re-download the voice model even if it already exists.
58+ use_cuda: Use CUDA for GPU-accelerated inference.
59+ **kwargs: Additional arguments passed to the parent `TTSService`.
60+ """
61+ super ().__init__ (** kwargs )
62+
63+ self ._voice_id = voice_id
64+
65+ download_dir = download_dir or Path .cwd ()
66+
67+ model_file = f"{ voice_id } .onnx"
68+ model_path = Path (download_dir ) / model_file
69+
70+ if not model_path .exists ():
71+ logger .debug (f"Downloading Piper '{ voice_id } ' model" )
72+ download_voice (voice_id , download_dir , force_redownload = force_redownload )
73+
74+ logger .debug (f"Loading Piper '{ voice_id } ' model from { model_path } " )
75+
76+ self ._voice = PiperVoice .load (model_path , use_cuda = use_cuda )
77+
78+ logger .debug (f"Loaded Piper '{ voice_id } ' model" )
79+
80+ def can_generate_metrics (self ) -> bool :
81+ """Check if this service can generate processing metrics.
82+
83+ Returns:
84+ True, as Piper service supports metrics generation.
85+ """
86+ return True
87+
88+ @traced_tts
89+ async def run_tts (self , text : str ) -> AsyncGenerator [Frame , None ]:
90+ """Generate speech from text using Piper.
91+
92+ Args:
93+ text: The text to convert to speech.
94+
95+ Yields:
96+ Frame: Audio frames containing the synthesized speech and status frames.
97+ """
98+
99+ def async_next (it ):
100+ try :
101+ return next (it )
102+ except StopIteration :
103+ return None
104+
105+ async def async_iterator (iterator ) -> AsyncIterator [bytes ]:
106+ while True :
107+ item = await asyncio .to_thread (async_next , iterator )
108+ if item is None :
109+ return
110+ yield item .audio_int16_bytes
111+
112+ logger .debug (f"{ self } : Generating TTS [{ text } ]" )
113+
114+ try :
115+ await self .start_ttfb_metrics ()
116+
117+ await self .start_tts_usage_metrics (text )
118+
119+ yield TTSStartedFrame ()
120+
121+ async for frame in self ._stream_audio_frames_from_iterator (
122+ async_iterator (self ._voice .synthesize (text )),
123+ in_sample_rate = self ._voice .config .sample_rate ,
124+ ):
125+ await self .stop_ttfb_metrics ()
126+ yield frame
127+ except Exception as e :
128+ logger .error (f"{ self } exception: { e } " )
129+ yield ErrorFrame (error = f"Unknown error occurred: { e } " )
130+ finally :
131+ logger .debug (f"{ self } : Finished TTS [{ text } ]" )
132+ await self .stop_ttfb_metrics ()
133+ yield TTSStoppedFrame ()
134+
135+
136+ # This assumes a running TTS service running:
137+ # https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/API_HTTP.md
138+ #
139+ # Usage:
140+ #
141+ # $ uv pip install "piper-tts[http]"
142+ # $ uv run python -m piper.http_server -m en_US-ryan-high
143+ #
144+ class PiperHttpTTSService (TTSService ):
145+ """Piper HTTP TTS service implementation.
146+
28147 Provides integration with Piper's HTTP TTS server for text-to-speech
29148 synthesis. Supports streaming audio generation with configurable sample
30149 rates and automatic WAV header removal.
@@ -35,28 +154,26 @@ def __init__(
35154 * ,
36155 base_url : str ,
37156 aiohttp_session : aiohttp .ClientSession ,
38- # When using Piper, the sample rate of the generated audio depends on the
39- # voice model being used.
40- sample_rate : Optional [int ] = None ,
157+ voice_id : Optional [str ] = None ,
41158 ** kwargs ,
42159 ):
43160 """Initialize the Piper TTS service.
44161
45162 Args:
46163 base_url: Base URL for the Piper TTS HTTP server.
47164 aiohttp_session: aiohttp ClientSession for making HTTP requests.
48- sample_rate: Output sample rate. If None, uses the voice model's native rate .
165+ voice_id: Piper voice model identifier (e.g. `en_US-ryan-high`) .
49166 **kwargs: Additional arguments passed to the parent TTSService.
50167 """
51- super ().__init__ (sample_rate = sample_rate , ** kwargs )
168+ super ().__init__ (** kwargs )
52169
53170 if base_url .endswith ("/" ):
54171 logger .warning ("Base URL ends with a slash, this is not allowed." )
55172 base_url = base_url [:- 1 ]
56173
57174 self ._base_url = base_url
58175 self ._session = aiohttp_session
59- self ._settings = { "base_url" : base_url }
176+ self ._model_id = voice_id
60177
61178 def can_generate_metrics (self ) -> bool :
62179 """Check if this service can generate processing metrics.
@@ -83,9 +200,12 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
83200 try :
84201 await self .start_ttfb_metrics ()
85202
86- async with self ._session .post (
87- self ._base_url , json = {"text" : text }, headers = headers
88- ) as response :
203+ data = {
204+ "text" : text ,
205+ "voice" : self ._model_id ,
206+ }
207+
208+ async with self ._session .post (self ._base_url , json = data , headers = headers ) as response :
89209 if response .status != 200 :
90210 error = await response .text ()
91211 yield ErrorFrame (
0 commit comments