11import asyncio
22import json
3+ from typing import Dict , List , Literal , Union
4+ from fastapi import WebSocket
35from prompty .tracer import trace
4- from typing import Dict , List , Literal
5- from fastapi import WebSocket , WebSocketDisconnect
6+ from fastapi import WebSocketDisconnect
67from pydantic import BaseModel
78from prompty .tracer import Tracer
89from fastapi .websockets import WebSocketState
@@ -26,26 +27,38 @@ class Message(BaseModel):
2627class RealtimeSession :
2728
2829 def __init__ (self , realtime : RealtimeVoiceClient , client : WebSocket ):
29- self .realtime : RealtimeVoiceClient = realtime
30- self .client : WebSocket = client
30+ self .realtime : Union [ RealtimeVoiceClient , None ] = realtime
31+ self .client : Union [ WebSocket , None ] = client
3132
3233 async def send_message (self , message : Message ):
33- await self .client .send_json (message .model_dump ())
34+ if self .client is not None :
35+ await self .client .send_json (message .model_dump ())
3436
3537 async def send_audio (self , audio : Message ):
3638 # send audio to client, format into bytes
37- await self .client .send_json (audio .model_dump ())
39+ if self .client is not None :
40+ await self .client .send_json (audio .model_dump ())
3841
3942 async def send_console (self , message : Message ):
40- await self .client .send_json (message .model_dump ())
41-
42- async def send_realtime_instructions (self , instructions : str ):
43- await self .realtime .send_session_update (instructions )
43+ if self .client is not None :
44+ await self .client .send_json (message .model_dump ())
45+
46+ async def send_realtime_instructions (
47+ self ,
48+ instructions : Union [str | None ] = None ,
49+ threshold : float = 0.8 ,
50+ silence_duration_ms : int = 500 ,
51+ prefix_padding_ms : int = 300 ,
52+ ):
53+ if self .realtime is not None :
54+ await self .realtime .send_session_update (
55+ instructions , threshold , silence_duration_ms , prefix_padding_ms
56+ )
4457
4558 @trace
4659 async def receive_realtime (self ):
4760 signature = "api.session.RealtimeSession.receive_realtime"
48- while self .realtime != None and not self .realtime .closed :
61+ while self .realtime is not None and not self .realtime .closed :
4962 async for message in self .realtime .receive_message ():
5063 # print("received message", message.type)
5164 if message is None :
@@ -57,7 +70,9 @@ async def receive_realtime(self):
5770 t (Tracer .SIGNATURE , signature )
5871 t (Tracer .INPUTS , message .content )
5972 await self .send_console (
60- Message (type = "console" , payload = json .dumps (message .content ))
73+ Message (
74+ type = "console" , payload = json .dumps (message .content )
75+ )
6176 )
6277 case "conversation.item.input_audio_transcription.completed" :
6378 with Tracer .start ("receive_user_transcript" ) as t :
@@ -70,9 +85,14 @@ async def receive_realtime(self):
7085 "content" : message .content ,
7186 },
7287 )
73- await self .send_message (
74- Message (type = "user" , payload = message .content )
75- )
88+ if (
89+ message .content is not None
90+ and isinstance (message .content , str )
91+ and message .content != ""
92+ ):
93+ await self .send_message (
94+ Message (type = "user" , payload = message .content )
95+ )
7696
7797 case "response.audio_transcript.done" :
7898 with Tracer .start ("receive_assistant_transcript" ) as t :
@@ -85,15 +105,25 @@ async def receive_realtime(self):
85105 "content" : message .content ,
86106 },
87107 )
88- # audio stream
89- await self .send_message (
90- Message (type = "assistant" , payload = message .content )
91- )
108+ if (
109+ message .content is not None
110+ and isinstance (message .content , str )
111+ and message .content != ""
112+ ):
113+ # audio stream
114+ await self .send_message (
115+ Message (type = "assistant" , payload = message .content )
116+ )
92117
93118 case "response.audio.delta" :
94- await self .send_audio (
95- Message (type = "audio" , payload = message .content )
96- )
119+ if (
120+ message .content is not None
121+ and isinstance (message .content , str )
122+ and message .content != ""
123+ ):
124+ await self .send_audio (
125+ Message (type = "audio" , payload = message .content )
126+ )
97127
98128 case "response.failed" :
99129 with Tracer .start ("realtime_failure" ) as t :
@@ -139,6 +169,8 @@ async def receive_realtime(self):
139169 @trace
140170 async def receive_client (self ):
141171 signature = "api.session.RealtimeSession.receive_client"
172+ if self .client is None or self .realtime is None :
173+ return
142174 try :
143175 while self .client .client_state != WebSocketState .DISCONNECTED :
144176 message = await self .client .receive_text ()
@@ -170,6 +202,8 @@ async def receive_client(self):
170202 print ("Realtime Socket Disconnected" )
171203
172204 async def close (self ):
205+ if self .client is None or self .realtime is None :
206+ return
173207 try :
174208 await self .client .close ()
175209 await self .realtime .close ()
@@ -182,7 +216,7 @@ async def close(self):
182216class ChatSession :
183217 def __init__ (self , client : WebSocket ):
184218 self .client = client
185- self .realtime : RealtimeSession = None
219+ self .realtime : Union [ RealtimeSession , None ] = None
186220 self .context : List [str ] = []
187221
188222 async def send_message (self , message : Message ):
@@ -193,16 +227,20 @@ def add_realtime(self, realtime: RealtimeSession):
193227
194228 def is_closed (self ):
195229 client_closed = (
196- self .client == None
230+ self .client is None
197231 or self .client .client_state == WebSocketState .DISCONNECTED
198232 )
199- realtime_closed = self .realtime == None or self .realtime .realtime .closed
233+ realtime_closed = (
234+ self .realtime is None
235+ or self .realtime .realtime is None
236+ or self .realtime .realtime .closed
237+ )
200238 return client_closed and realtime_closed
201239
202240 @trace
203241 async def receive_chat (self ):
204242 while (
205- self .client != None
243+ self .client is not None
206244 and self .client .client_state != WebSocketState .DISCONNECTED
207245 ):
208246 with Tracer .start ("chat_turn" ) as t :
@@ -214,7 +252,7 @@ async def receive_chat(self):
214252 Tracer .INPUTS ,
215253 {
216254 "request" : msg .text ,
217- "image" : msg .image != None ,
255+ "image" : msg .image is not None ,
218256 },
219257 )
220258
@@ -250,9 +288,10 @@ async def receive_chat(self):
250288 },
251289 )
252290
253-
254-
255291 async def start_realtime (self , prompt : str ):
292+ if self .realtime is None :
293+ raise Exception ("Realtime session not available" )
294+
256295 await self .realtime .send_realtime_instructions (prompt )
257296 tasks = [
258297 asyncio .create_task (self .realtime .receive_realtime ()),
0 commit comments