Skip to content

Commit 4d974e9

Browse files
committed
added voice settings
1 parent 750983e commit 4d974e9

File tree

12 files changed

+313
-69
lines changed

12 files changed

+313
-69
lines changed

api/chat/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import List
2+
from typing import List, Union
33
import prompty
44
import prompty.azure
55
from prompty.tracer import trace
@@ -10,7 +10,10 @@
1010

1111
@trace
1212
async def create_response(
13-
customer: str, question: str, context: List[str] = [], image: str = None
13+
customer: str,
14+
question: str,
15+
context: List[str] = [],
16+
image: Union[str, None] = None,
1417
):
1518
inputs = {"customer": customer, "question": question, "context": context}
1619
if image:
@@ -26,7 +29,7 @@ async def create_response(
2629

2730
customer = "Seth Juarez"
2831
question = "My friend just sent me this and I'm worried I don't have the right gear for my camping trip. Can you help me? CALL ME"
29-
context = []
32+
context: List[str] = []
3033
image = "winter.jpg"
3134

3235
asyncio.run(create_response(customer, question, context, image))

api/main.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@
55
from typing import List
66
from fastapi.responses import StreamingResponse
77
from jinja2 import Environment, FileSystemLoader
8-
from rtclient import RTLowLevelClient
9-
from api import repeat
8+
9+
from rtclient import RTLowLevelClient # type: ignore
1010
from api.realtime import RealtimeVoiceClient
1111
from api.session import Message, RealtimeSession, SessionManager
1212
from contextlib import asynccontextmanager
1313
from fastapi.middleware.cors import CORSMiddleware
1414
from azure.core.credentials import AzureKeyCredential
1515
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
16-
from api.session import SessionManager
1716
from api.suggestions import SimpleMessage, create_suggestion, suggestion_requested
18-
from prompty.tracer import Tracer, trace
1917
from dotenv import load_dotenv
2018
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
2119

@@ -24,7 +22,7 @@
2422
load_dotenv()
2523

2624
AZURE_VOICE_ENDPOINT = os.getenv("AZURE_VOICE_ENDPOINT")
27-
AZURE_VOICE_KEY = os.getenv("AZURE_VOICE_KEY")
25+
AZURE_VOICE_KEY = os.getenv("AZURE_VOICE_KEY", "fake_key")
2826

2927
LOCAL_TRACING_ENABLED = os.getenv("LOCAL_TRACING_ENABLED", "true") == "true"
3028
init_tracing(local_tracing=LOCAL_TRACING_ENABLED)
@@ -120,20 +118,34 @@ async def voice_endpoint(websocket: WebSocket):
120118
message = Message(**chat_items)
121119

122120
# get current username
121+
# and receive any parameters
123122
user_message = await websocket.receive_json()
124123
user = Message(**user_message)
125124

125+
settings = json.loads(user.payload)
126+
print(
127+
"Starting voice session with settings:\n",
128+
json.dumps(settings, indent=2),
129+
)
130+
126131
# create voice system message
127132
# TODO: retrieve context from chat messages via thread id
128133
system_message = env.get_template("script.jinja2").render(
129-
customer=user.payload,
134+
customer=settings["user"] if "user" in settings else "Seth",
130135
purchases=purchases,
131136
context=json.loads(message.payload),
132137
products=products,
133138
)
134139

135140
session = RealtimeSession(RealtimeVoiceClient(rt, verbose=False), websocket)
136-
await session.send_realtime_instructions(system_message)
141+
await session.send_realtime_instructions(
142+
system_message,
143+
threshold=settings["threshold"] if "threshold" in settings else 0.8,
144+
silence_duration_ms=(
145+
settings["silence"] if "silence" in settings else 500
146+
),
147+
prefix_padding_ms=(settings["prefix"] if "prefix" in settings else 300),
148+
)
137149
tasks = [
138150
asyncio.create_task(session.receive_realtime()),
139151
asyncio.create_task(session.receive_client()),
@@ -144,10 +156,4 @@ async def voice_endpoint(websocket: WebSocket):
144156
print("Voice Socket Disconnected", e)
145157

146158

147-
@repeat(seconds=60)
148-
@trace
149-
async def cleanup_sessions():
150-
await SessionManager.clear_closed_sessions()
151-
152-
153159
FastAPIInstrumentor.instrument_app(app, exclude_spans=["send", "receive"])

api/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pydantic import BaseModel
2-
from typing import List, Literal, Optional
2+
from typing import Literal, Optional
33

44

55
class Action(BaseModel):

api/realtime.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass
2-
import json
3-
from typing import Any, AsyncGenerator
4-
from rtclient import (
2+
from typing import Any, AsyncGenerator, Union
3+
from prompty.tracer import trace
4+
5+
from rtclient import ( # type: ignore
56
InputAudioBufferAppendMessage,
67
InputAudioBufferClearMessage,
78
InputAudioTranscription,
@@ -14,8 +15,7 @@
1415
SessionUpdateParams,
1516
SystemMessageItem,
1617
UserMessageItem,
17-
)
18-
from prompty.tracer import trace, Tracer
18+
)
1919

2020
# things to have in here
2121
# - send user message
@@ -125,16 +125,22 @@ async def send_system_message(self, message: str):
125125
)
126126

127127
@trace
128-
async def send_session_update(self, instructions: str = None):
128+
async def send_session_update(
129+
self,
130+
instructions: Union[str | None] = None,
131+
threshold: float = 0.8,
132+
silence_duration_ms: int = 500,
133+
prefix_padding_ms: int = 300,
134+
):
129135
if self.client is None:
130136
raise Exception("Client not set")
131137

132138
session = SessionUpdateParams(
133139
turn_detection=ServerVAD(
134140
type="server_vad",
135-
threshold=0.8,
136-
silence_duration_ms=500,
137-
prefix_padding_ms=300,
141+
threshold=threshold,
142+
silence_duration_ms=silence_duration_ms,
143+
prefix_padding_ms=prefix_padding_ms,
138144
),
139145
input_audio_transcription=InputAudioTranscription(model="whisper-1"),
140146
voice="shimmer",

api/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
fastapi[standard]
22
websockets
33
python-dotenv
4-
prompty[azure,serverless]==0.1.39
4+
prompty[azure,serverless]==0.1.47
55
opentelemetry-instrumentation
66
azure-monitor-opentelemetry-exporter
77
opentelemetry-instrumentation-fastapi

api/session.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
import json
3+
from typing import Dict, List, Literal, Union
4+
from fastapi import WebSocket
35
from prompty.tracer import trace
4-
from typing import Dict, List, Literal
5-
from fastapi import WebSocket, WebSocketDisconnect
6+
from fastapi import WebSocketDisconnect
67
from pydantic import BaseModel
78
from prompty.tracer import Tracer
89
from fastapi.websockets import WebSocketState
@@ -26,26 +27,38 @@ class Message(BaseModel):
2627
class RealtimeSession:
2728

2829
def __init__(self, realtime: RealtimeVoiceClient, client: WebSocket):
29-
self.realtime: RealtimeVoiceClient = realtime
30-
self.client: WebSocket = client
30+
self.realtime: Union[RealtimeVoiceClient, None] = realtime
31+
self.client: Union[WebSocket, None] = client
3132

3233
async def send_message(self, message: Message):
33-
await self.client.send_json(message.model_dump())
34+
if self.client is not None:
35+
await self.client.send_json(message.model_dump())
3436

3537
async def send_audio(self, audio: Message):
3638
# send audio to client, format into bytes
37-
await self.client.send_json(audio.model_dump())
39+
if self.client is not None:
40+
await self.client.send_json(audio.model_dump())
3841

3942
async def send_console(self, message: Message):
40-
await self.client.send_json(message.model_dump())
41-
42-
async def send_realtime_instructions(self, instructions: str):
43-
await self.realtime.send_session_update(instructions)
43+
if self.client is not None:
44+
await self.client.send_json(message.model_dump())
45+
46+
async def send_realtime_instructions(
47+
self,
48+
instructions: Union[str | None] = None,
49+
threshold: float = 0.8,
50+
silence_duration_ms: int = 500,
51+
prefix_padding_ms: int = 300,
52+
):
53+
if self.realtime is not None:
54+
await self.realtime.send_session_update(
55+
instructions, threshold, silence_duration_ms, prefix_padding_ms
56+
)
4457

4558
@trace
4659
async def receive_realtime(self):
4760
signature = "api.session.RealtimeSession.receive_realtime"
48-
while self.realtime != None and not self.realtime.closed:
61+
while self.realtime is not None and not self.realtime.closed:
4962
async for message in self.realtime.receive_message():
5063
# print("received message", message.type)
5164
if message is None:
@@ -57,7 +70,9 @@ async def receive_realtime(self):
5770
t(Tracer.SIGNATURE, signature)
5871
t(Tracer.INPUTS, message.content)
5972
await self.send_console(
60-
Message(type="console", payload=json.dumps(message.content))
73+
Message(
74+
type="console", payload=json.dumps(message.content)
75+
)
6176
)
6277
case "conversation.item.input_audio_transcription.completed":
6378
with Tracer.start("receive_user_transcript") as t:
@@ -70,9 +85,14 @@ async def receive_realtime(self):
7085
"content": message.content,
7186
},
7287
)
73-
await self.send_message(
74-
Message(type="user", payload=message.content)
75-
)
88+
if (
89+
message.content is not None
90+
and isinstance(message.content, str)
91+
and message.content != ""
92+
):
93+
await self.send_message(
94+
Message(type="user", payload=message.content)
95+
)
7696

7797
case "response.audio_transcript.done":
7898
with Tracer.start("receive_assistant_transcript") as t:
@@ -85,15 +105,25 @@ async def receive_realtime(self):
85105
"content": message.content,
86106
},
87107
)
88-
# audio stream
89-
await self.send_message(
90-
Message(type="assistant", payload=message.content)
91-
)
108+
if (
109+
message.content is not None
110+
and isinstance(message.content, str)
111+
and message.content != ""
112+
):
113+
# audio stream
114+
await self.send_message(
115+
Message(type="assistant", payload=message.content)
116+
)
92117

93118
case "response.audio.delta":
94-
await self.send_audio(
95-
Message(type="audio", payload=message.content)
96-
)
119+
if (
120+
message.content is not None
121+
and isinstance(message.content, str)
122+
and message.content != ""
123+
):
124+
await self.send_audio(
125+
Message(type="audio", payload=message.content)
126+
)
97127

98128
case "response.failed":
99129
with Tracer.start("realtime_failure") as t:
@@ -139,6 +169,8 @@ async def receive_realtime(self):
139169
@trace
140170
async def receive_client(self):
141171
signature = "api.session.RealtimeSession.receive_client"
172+
if self.client is None or self.realtime is None:
173+
return
142174
try:
143175
while self.client.client_state != WebSocketState.DISCONNECTED:
144176
message = await self.client.receive_text()
@@ -170,6 +202,8 @@ async def receive_client(self):
170202
print("Realtime Socket Disconnected")
171203

172204
async def close(self):
205+
if self.client is None or self.realtime is None:
206+
return
173207
try:
174208
await self.client.close()
175209
await self.realtime.close()
@@ -182,7 +216,7 @@ async def close(self):
182216
class ChatSession:
183217
def __init__(self, client: WebSocket):
184218
self.client = client
185-
self.realtime: RealtimeSession = None
219+
self.realtime: Union[RealtimeSession, None] = None
186220
self.context: List[str] = []
187221

188222
async def send_message(self, message: Message):
@@ -193,16 +227,20 @@ def add_realtime(self, realtime: RealtimeSession):
193227

194228
def is_closed(self):
195229
client_closed = (
196-
self.client == None
230+
self.client is None
197231
or self.client.client_state == WebSocketState.DISCONNECTED
198232
)
199-
realtime_closed = self.realtime == None or self.realtime.realtime.closed
233+
realtime_closed = (
234+
self.realtime is None
235+
or self.realtime.realtime is None
236+
or self.realtime.realtime.closed
237+
)
200238
return client_closed and realtime_closed
201239

202240
@trace
203241
async def receive_chat(self):
204242
while (
205-
self.client != None
243+
self.client is not None
206244
and self.client.client_state != WebSocketState.DISCONNECTED
207245
):
208246
with Tracer.start("chat_turn") as t:
@@ -214,7 +252,7 @@ async def receive_chat(self):
214252
Tracer.INPUTS,
215253
{
216254
"request": msg.text,
217-
"image": msg.image != None,
255+
"image": msg.image is not None,
218256
},
219257
)
220258

@@ -250,9 +288,10 @@ async def receive_chat(self):
250288
},
251289
)
252290

253-
254-
255291
async def start_realtime(self, prompt: str):
292+
if self.realtime is None:
293+
raise Exception("Realtime session not available")
294+
256295
await self.realtime.send_realtime_instructions(prompt)
257296
tasks = [
258297
asyncio.create_task(self.realtime.receive_realtime()),

0 commit comments

Comments
 (0)