Skip to content

Commit f3b72e9

Browse files
authored
Merge pull request #3585 from pipecat-ai/aleix/improve-piper-tts-support
improve Piper TTS support
2 parents b77a50d + bd00587 commit f3b72e9

12 files changed

Lines changed: 349 additions & 32 deletions

File tree

.github/workflows/coverage.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@ jobs:
3333
3434
- name: Install dependencies
3535
run: |
36-
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
36+
uv sync --group dev \
37+
--extra anthropic \
38+
--extra aws \
39+
--extra google \
40+
--extra langchain \
41+
--extra livekit \
42+
--extra piper \
43+
--extra websocket
3744
3845
- name: Run tests with coverage
3946
run: |

.github/workflows/tests.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@ jobs:
3737
3838
- name: Install dependencies
3939
run: |
40-
uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
40+
uv sync --group dev \
41+
--extra anthropic \
42+
--extra aws \
43+
--extra google \
44+
--extra langchain \
45+
--extra livekit \
46+
--extra piper \
47+
--extra websocket
4148
4249
- name: Test with pytest
4350
run: |

changelog/3585.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Added local `PiperTTSService` for offline text-to-speech using Piper voice models. The existing HTTP-based service has been renamed to `PiperHttpTTSService`.

changelog/3585.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Fixed `PiperHttpTTSService` (formerly `PiperTTSService`) to resample audio output based on the model's sample rate parsed from the WAV header.

examples/foundational/01-say-one-thing-piper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pipecat.pipeline.task import PipelineTask
1717
from pipecat.runner.types import RunnerArguments
1818
from pipecat.runner.utils import create_transport
19-
from pipecat.services.piper.tts import PiperTTSService
19+
from pipecat.services.piper.tts import PiperHttpTTSService
2020
from pipecat.transports.base_transport import BaseTransport, TransportParams
2121
from pipecat.transports.daily.transport import DailyParams
2222
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
@@ -39,7 +39,7 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
3939

4040
# Create an HTTP session
4141
async with aiohttp.ClientSession() as session:
42-
tts = PiperTTSService(
42+
tts = PiperHttpTTSService(
4343
base_url=os.getenv("PIPER_BASE_URL"), aiohttp_session=session, sample_rate=24000
4444
)
4545

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#
2+
# Copyright (c) 2024-2026, Daily
3+
#
4+
# SPDX-License-Identifier: BSD 2-Clause License
5+
#
6+
7+
import os
8+
9+
from dotenv import load_dotenv
10+
from loguru import logger
11+
12+
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
13+
from pipecat.audio.vad.silero import SileroVADAnalyzer
14+
from pipecat.audio.vad.vad_analyzer import VADParams
15+
from pipecat.frames.frames import LLMRunFrame
16+
from pipecat.pipeline.pipeline import Pipeline
17+
from pipecat.pipeline.runner import PipelineRunner
18+
from pipecat.pipeline.task import PipelineParams, PipelineTask
19+
from pipecat.processors.aggregators.llm_context import LLMContext
20+
from pipecat.processors.aggregators.llm_response_universal import (
21+
LLMContextAggregatorPair,
22+
LLMUserAggregatorParams,
23+
)
24+
from pipecat.runner.types import RunnerArguments
25+
from pipecat.runner.utils import create_transport
26+
from pipecat.services.deepgram.stt import DeepgramSTTService
27+
from pipecat.services.openai.llm import OpenAILLMService
28+
from pipecat.services.piper.tts import PiperTTSService
29+
from pipecat.transports.base_transport import BaseTransport, TransportParams
30+
from pipecat.transports.daily.transport import DailyParams
31+
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
32+
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
33+
from pipecat.turns.user_turn_strategies import UserTurnStrategies
34+
35+
load_dotenv(override=True)
36+
37+
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
38+
# instantiated. The function will be called when the desired transport gets
39+
# selected.
40+
transport_params = {
41+
"daily": lambda: DailyParams(
42+
audio_in_enabled=True,
43+
audio_out_enabled=True,
44+
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
45+
),
46+
"twilio": lambda: FastAPIWebsocketParams(
47+
audio_in_enabled=True,
48+
audio_out_enabled=True,
49+
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
50+
),
51+
"webrtc": lambda: TransportParams(
52+
audio_in_enabled=True,
53+
audio_out_enabled=True,
54+
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
55+
),
56+
}
57+
58+
59+
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
60+
logger.info(f"Starting bot")
61+
62+
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
63+
64+
tts = PiperTTSService(voice_id="en_US-ryan-high")
65+
66+
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
67+
68+
messages = [
69+
{
70+
"role": "system",
71+
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
72+
},
73+
]
74+
75+
context = LLMContext(messages)
76+
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
77+
context,
78+
user_params=LLMUserAggregatorParams(
79+
user_turn_strategies=UserTurnStrategies(
80+
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
81+
),
82+
),
83+
)
84+
85+
pipeline = Pipeline(
86+
[
87+
transport.input(), # Transport user input
88+
stt,
89+
user_aggregator, # User responses
90+
llm, # LLM
91+
tts, # TTS
92+
transport.output(), # Transport bot output
93+
assistant_aggregator, # Assistant spoken responses
94+
]
95+
)
96+
97+
task = PipelineTask(
98+
pipeline,
99+
params=PipelineParams(
100+
enable_metrics=True,
101+
enable_usage_metrics=True,
102+
),
103+
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
104+
)
105+
106+
@transport.event_handler("on_client_connected")
107+
async def on_client_connected(transport, client):
108+
logger.info(f"Client connected")
109+
# Kick off the conversation.
110+
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
111+
await task.queue_frames([LLMRunFrame()])
112+
113+
@transport.event_handler("on_client_disconnected")
114+
async def on_client_disconnected(transport, client):
115+
logger.info(f"Client disconnected")
116+
await task.cancel()
117+
118+
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
119+
120+
await runner.run(task)
121+
122+
123+
async def bot(runner_args: RunnerArguments):
124+
"""Main bot entry point compatible with Pipecat Cloud."""
125+
transport = await create_transport(runner_args, transport_params)
126+
await run_bot(transport, runner_args)
127+
128+
129+
if __name__ == "__main__":
130+
from pipecat.runner.run import main
131+
132+
main()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ rnnoise = [ "pyrnnoise~=0.4.1" ]
9595
openpipe = [ "openpipe>=4.50.0,<6" ]
9696
openrouter = []
9797
perplexity = []
98+
piper = [ "piper-tts>=1.3.0,<2" ]
9899
playht = [ "pipecat-ai[websockets-base]" ]
99100
qwen = []
100101
remote-smart-turn = []

scripts/evals/run-release-evals.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def EVAL_VISION_IMAGE(*, eval_speaks_first: bool = False):
138138
("07zf-interruptible-gradium.py", EVAL_SIMPLE_MATH),
139139
("07zg-interruptible-camb.py", EVAL_SIMPLE_MATH),
140140
("07zh-interruptible-hathora.py", EVAL_SIMPLE_MATH),
141+
("07zi-interruptible-piper.py", EVAL_SIMPLE_MATH),
141142
# Needs a local XTTS docker instance running.
142143
# ("07i-interruptible-xtts.py", EVAL_SIMPLE_MATH),
143144
# Needs a Krisp license.

src/pipecat/services/piper/tts.py

Lines changed: 131 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
"""Piper TTS service implementation."""
88

9-
from typing import AsyncGenerator, Optional
9+
import asyncio
10+
from pathlib import Path
11+
from typing import AsyncGenerator, AsyncIterator, Optional
1012

1113
import aiohttp
1214
from loguru import logger
@@ -20,11 +22,128 @@
2022
from pipecat.services.tts_service import TTSService
2123
from pipecat.utils.tracing.service_decorators import traced_tts
2224

25+
try:
26+
from piper import PiperVoice
27+
from piper.download_voices import download_voice
28+
except ModuleNotFoundError as e:
29+
logger.error(f"Exception: {e}")
30+
logger.error("In order to use Piper, you need to `pip install pipecat-ai[piper]`.")
31+
raise Exception(f"Missing module: {e}")
32+
2333

24-
# This assumes a running TTS service: https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/API_HTTP.md
2534
class PiperTTSService(TTSService):
2635
"""Piper TTS service implementation.
2736
37+
Provides local text-to-speech synthesis using Piper voice models. Automatically
38+
downloads voice models if not already present and resamples audio output to
39+
match the configured sample rate.
40+
"""
41+
42+
def __init__(
43+
self,
44+
*,
45+
voice_id: str,
46+
download_dir: Optional[Path] = None,
47+
force_redownload: bool = False,
48+
use_cuda: bool = False,
49+
**kwargs,
50+
):
51+
"""Initialize the Piper TTS service.
52+
53+
Args:
54+
voice_id: Piper voice model identifier (e.g. `en_US-ryan-high`).
55+
download_dir: Directory for storing voice model files. Defaults to
56+
the current working directory.
57+
force_redownload: Re-download the voice model even if it already exists.
58+
use_cuda: Use CUDA for GPU-accelerated inference.
59+
**kwargs: Additional arguments passed to the parent `TTSService`.
60+
"""
61+
super().__init__(**kwargs)
62+
63+
self._voice_id = voice_id
64+
65+
download_dir = download_dir or Path.cwd()
66+
67+
model_file = f"{voice_id}.onnx"
68+
model_path = Path(download_dir) / model_file
69+
70+
if not model_path.exists():
71+
logger.debug(f"Downloading Piper '{voice_id}' model")
72+
download_voice(voice_id, download_dir, force_redownload=force_redownload)
73+
74+
logger.debug(f"Loading Piper '{voice_id}' model from {model_path}")
75+
76+
self._voice = PiperVoice.load(model_path, use_cuda=use_cuda)
77+
78+
logger.debug(f"Loaded Piper '{voice_id}' model")
79+
80+
def can_generate_metrics(self) -> bool:
81+
"""Check if this service can generate processing metrics.
82+
83+
Returns:
84+
True, as Piper service supports metrics generation.
85+
"""
86+
return True
87+
88+
@traced_tts
89+
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
90+
"""Generate speech from text using Piper.
91+
92+
Args:
93+
text: The text to convert to speech.
94+
95+
Yields:
96+
Frame: Audio frames containing the synthesized speech and status frames.
97+
"""
98+
99+
def async_next(it):
100+
try:
101+
return next(it)
102+
except StopIteration:
103+
return None
104+
105+
async def async_iterator(iterator) -> AsyncIterator[bytes]:
106+
while True:
107+
item = await asyncio.to_thread(async_next, iterator)
108+
if item is None:
109+
return
110+
yield item.audio_int16_bytes
111+
112+
logger.debug(f"{self}: Generating TTS [{text}]")
113+
114+
try:
115+
await self.start_ttfb_metrics()
116+
117+
await self.start_tts_usage_metrics(text)
118+
119+
yield TTSStartedFrame()
120+
121+
async for frame in self._stream_audio_frames_from_iterator(
122+
async_iterator(self._voice.synthesize(text)),
123+
in_sample_rate=self._voice.config.sample_rate,
124+
):
125+
await self.stop_ttfb_metrics()
126+
yield frame
127+
except Exception as e:
128+
logger.error(f"{self} exception: {e}")
129+
yield ErrorFrame(error=f"Unknown error occurred: {e}")
130+
finally:
131+
logger.debug(f"{self}: Finished TTS [{text}]")
132+
await self.stop_ttfb_metrics()
133+
yield TTSStoppedFrame()
134+
135+
136+
# This assumes a running TTS service:
137+
# https://github.com/OHF-Voice/piper1-gpl/blob/main/docs/API_HTTP.md
138+
#
139+
# Usage:
140+
#
141+
# $ uv pip install "piper-tts[http]"
142+
# $ uv run python -m piper.http_server -m en_US-ryan-high
143+
#
144+
class PiperHttpTTSService(TTSService):
145+
"""Piper HTTP TTS service implementation.
146+
28147
Provides integration with Piper's HTTP TTS server for text-to-speech
29148
synthesis. Supports streaming audio generation with configurable sample
30149
rates and automatic WAV header removal.
@@ -35,28 +154,26 @@ def __init__(
35154
*,
36155
base_url: str,
37156
aiohttp_session: aiohttp.ClientSession,
38-
# When using Piper, the sample rate of the generated audio depends on the
39-
# voice model being used.
40-
sample_rate: Optional[int] = None,
157+
voice_id: Optional[str] = None,
41158
**kwargs,
42159
):
43160
"""Initialize the Piper TTS service.
44161
45162
Args:
46163
base_url: Base URL for the Piper TTS HTTP server.
47164
aiohttp_session: aiohttp ClientSession for making HTTP requests.
48-
sample_rate: Output sample rate. If None, uses the voice model's native rate.
165+
voice_id: Piper voice model identifier (e.g. `en_US-ryan-high`).
49166
**kwargs: Additional arguments passed to the parent TTSService.
50167
"""
51-
super().__init__(sample_rate=sample_rate, **kwargs)
168+
super().__init__(**kwargs)
52169

53170
if base_url.endswith("/"):
54171
logger.warning("Base URL ends with a slash, this is not allowed.")
55172
base_url = base_url[:-1]
56173

57174
self._base_url = base_url
58175
self._session = aiohttp_session
59-
self._settings = {"base_url": base_url}
176+
self._model_id = voice_id
60177

61178
def can_generate_metrics(self) -> bool:
62179
"""Check if this service can generate processing metrics.
@@ -83,9 +200,12 @@ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
83200
try:
84201
await self.start_ttfb_metrics()
85202

86-
async with self._session.post(
87-
self._base_url, json={"text": text}, headers=headers
88-
) as response:
203+
data = {
204+
"text": text,
205+
"voice": self._model_id,
206+
}
207+
208+
async with self._session.post(self._base_url, json=data, headers=headers) as response:
89209
if response.status != 200:
90210
error = await response.text()
91211
yield ErrorFrame(

0 commit comments

Comments
 (0)