Skip to content

Commit 3787b15

Browse files
committed
cleaned up some internal errors
1 parent fc199e5 commit 3787b15

File tree

3 files changed

+96
-95
lines changed

3 files changed

+96
-95
lines changed

api/agent/handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ async def add_message(
7171
**{
7272
key: value
7373
for key, value in tool_call[tool_call["type"]].items()
74+
if tool_call["type"] in tool_call
7475
},
7576
}
7677
for tool_call in tool_calls

api/connection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
from fastapi.websockets import WebSocketState
55

66

7-
# MessageType: TypeAlias = Literal[
8-
# "user", "assistant", "system", "audio", "console", "interrupt", "function", "agent"
9-
# ]
10-
117
class Connection:
128
def __init__(self, websocket: WebSocket):
139
self.websocket = websocket
@@ -28,6 +24,9 @@ async def close(self):
2824
if self.websocket.client_state == WebSocketState.CONNECTED:
2925
await self.websocket.close()
3026

27+
async def get_state(self) -> WebSocketState:
28+
return self.websocket.client_state
29+
3130
@property
3231
def state(self) -> WebSocketState:
3332
return self.websocket.client_state
@@ -41,7 +40,8 @@ def __init__(self):
4140
async def connect(self, id: str, websocket: WebSocket) -> Connection:
4241
# destroy existing connection if it exists
4342
if id in self.active_connections:
44-
await self.active_connections[id].close()
43+
if self.active_connections[id].state == WebSocketState.CONNECTED:
44+
await self.active_connections[id].close()
4545
del self.active_connections[id]
4646

4747
await websocket.accept()

api/voice/session.py

Lines changed: 90 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -155,95 +155,93 @@ async def receive_realtime(self):
155155

156156
match event.type:
157157
case "error":
158-
print(json.dumps(event.model_dump(), indent=2))
159-
# await self._handle_error(event)
158+
159+
await self.handle_error(event)
160160
case "session.created":
161-
await self._session_created(event)
161+
await self.session_created(event)
162162
case "session.updated":
163-
await self._session_updated(event)
163+
await self.session_updated(event)
164164
case "conversation.created":
165-
await self._conversation_created(event)
165+
await self.conversation_created(event)
166166
case "conversation.item.created":
167-
await self._conversation_item_created(event)
167+
await self.conversation_item_created(event)
168168
case "conversation.item.input_audio_transcription.completed":
169-
await self._conversation_item_input_audio_transcription_completed(
169+
await self.conversation_item_input_audio_transcription_completed(
170170
event
171171
)
172172
case "conversation.item.input_audio_transcription.delta":
173-
await self._conversation_item_input_audio_transcription_delta(event)
173+
await self.conversation_item_input_audio_transcription_delta(event) # type: ignore
174174
case "conversation.item.input_audio_transcription.failed":
175-
await self._conversation_item_input_audio_transcription_failed(
176-
event
177-
)
175+
await self.conversation_item_input_audio_transcription_failed(event)
178176
case "conversation.item.truncated":
179-
await self._conversation_item_truncated(event)
177+
await self.conversation_item_truncated(event)
180178
case "conversation.item.deleted":
181-
await self._conversation_item_deleted(event)
179+
await self.conversation_item_deleted(event)
182180
case "input_audio_buffer.committed":
183-
await self._input_audio_buffer_committed(event)
181+
await self.input_audio_buffer_committed(event)
184182
case "input_audio_buffer.cleared":
185-
await self._input_audio_buffer_cleared(event)
183+
await self.input_audio_buffer_cleared(event)
186184
case "input_audio_buffer.speech_started":
187-
await self._input_audio_buffer_speech_started(event)
185+
await self.input_audio_buffer_speech_started(event)
188186
case "input_audio_buffer.speech_stopped":
189-
await self._input_audio_buffer_speech_stopped(event)
187+
await self.input_audio_buffer_speech_stopped(event)
190188
case "response.created":
191-
await self._response_created(event)
189+
await self.response_created(event)
192190
case "response.done":
193-
await self._response_done(event)
191+
await self.response_done(event)
194192
case "response.output_item.added":
195-
await self._response_output_item_added(event)
193+
await self.response_output_item_added(event)
196194
case "response.output_item.done":
197-
await self._response_output_item_done(event)
195+
await self.response_output_item_done(event)
198196
case "response.content_part.added":
199-
await self._response_content_part_added(event)
197+
await self.response_content_part_added(event)
200198
case "response.content_part.done":
201-
await self._response_content_part_done(event)
199+
await self.response_content_part_done(event)
202200
case "response.text.delta":
203-
await self._response_text_delta(event)
201+
await self.response_text_delta(event) # type: ignore
204202
case "response.text.done":
205-
await self._response_text_done(event)
203+
await self.response_text_done(event)
206204
case "response.audio_transcript.delta":
207-
await self._response_audio_transcript_delta(event)
205+
await self.response_audio_transcript_delta(event) # type: ignore
208206
case "response.audio_transcript.done":
209-
await self._response_audio_transcript_done(event)
207+
await self.response_audio_transcript_done(event)
210208
case "response.audio.delta":
211-
await self._response_audio_delta(event)
209+
await self.response_audio_delta(event) # type: ignore
212210
case "response.audio.done":
213-
await self._response_audio_done(event)
211+
await self.response_audio_done(event) # type: ignore
214212
case "response.function_call_arguments.delta":
215-
await self._response_function_call_arguments_delta(event)
213+
await self.response_function_call_arguments_delta(event) # type: ignore
216214
case "response.function_call_arguments.done":
217-
await self._response_function_call_arguments_done(event)
215+
await self.response_function_call_arguments_done(event) # type: ignore
218216
case "rate_limits.updated":
219-
await self._rate_limits_updated(event)
217+
await self.rate_limits_updated(event)
220218
case _:
221219
print(
222220
f"Unhandled event type {event.type}",
223221
)
224222

225-
@trace(name="error")
226-
async def _handle_error(self, event: ErrorEvent):
227-
print("Error event", event.error)
223+
@trace
224+
async def handle_error(self, event: ErrorEvent):
225+
print(json.dumps(event.model_dump(), indent=2))
228226

229-
@trace(name="session.created")
230-
async def _session_created(self, event: SessionCreatedEvent):
227+
@trace
228+
async def session_created(self, event: SessionCreatedEvent):
231229
pass
232230

233-
@trace(name="session.updated")
234-
async def _session_updated(self, event: SessionUpdatedEvent):
231+
@trace
232+
async def session_updated(self, event: SessionUpdatedEvent):
235233
pass
236234

237-
@trace(name="conversation.created")
238-
async def _conversation_created(self, event: ConversationCreatedEvent):
235+
@trace
236+
async def conversation_created(self, event: ConversationCreatedEvent):
239237
pass
240238

241-
@trace(name="conversation.item.created")
242-
async def _conversation_item_created(self, event: ConversationItemCreatedEvent):
239+
@trace
240+
async def conversation_item_created(self, event: ConversationItemCreatedEvent):
243241
pass
244242

245-
@trace(name="conversation.item.input_audio_transcription.completed")
246-
async def _conversation_item_input_audio_transcription_completed(
243+
@trace
244+
async def conversation_item_input_audio_transcription_completed(
247245
self, event: ConversationItemInputAudioTranscriptionCompletedEvent
248246
):
249247
if event.transcript is None or len(event.transcript.strip()) == 0:
@@ -269,53 +267,51 @@ async def _conversation_item_input_audio_transcription_completed(
269267
},
270268
)
271269

272-
async def _conversation_item_input_audio_transcription_delta(
270+
async def conversation_item_input_audio_transcription_delta(
273271
self, event: ConversationItemInputAudioTranscriptionDeltaEvent
274272
):
275273
pass
276274

277-
@trace(name="conversation.item.input_audio_transcription.failed")
278-
async def _conversation_item_input_audio_transcription_failed(
275+
@trace
276+
async def conversation_item_input_audio_transcription_failed(
279277
self, event: ConversationItemInputAudioTranscriptionFailedEvent
280278
):
281279
pass
282280

283-
@trace(name="conversation.item.truncated")
284-
async def _conversation_item_truncated(self, event: ConversationItemTruncatedEvent):
281+
@trace
282+
async def conversation_item_truncated(self, event: ConversationItemTruncatedEvent):
285283
pass
286284

287-
@trace(name="conversation.item.deleted")
288-
async def _conversation_item_deleted(self, event: ConversationItemDeletedEvent):
285+
@trace
286+
async def conversation_item_deleted(self, event: ConversationItemDeletedEvent):
289287
pass
290288

291-
@trace(name="input_audio_buffer.committed")
292-
async def _input_audio_buffer_committed(
293-
self, event: InputAudioBufferCommittedEvent
294-
):
289+
@trace
290+
async def input_audio_buffer_committed(self, event: InputAudioBufferCommittedEvent):
295291
pass
296292

297-
@trace(name="input_audio_buffer.cleared")
298-
async def _input_audio_buffer_cleared(self, event: InputAudioBufferClearedEvent):
293+
@trace
294+
async def input_audio_buffer_cleared(self, event: InputAudioBufferClearedEvent):
299295
pass
300296

301-
@trace(name="input_audio_buffer.speech_started")
302-
async def _input_audio_buffer_speech_started(
297+
@trace
298+
async def input_audio_buffer_speech_started(
303299
self, event: InputAudioBufferSpeechStartedEvent
304300
):
305301
await self.connection.send_update(Update.interrupt())
306302

307-
@trace(name="input_audio_buffer.speech_stopped")
308-
async def _input_audio_buffer_speech_stopped(
303+
@trace
304+
async def input_audio_buffer_speech_stopped(
309305
self, event: InputAudioBufferSpeechStoppedEvent
310306
):
311307
pass
312308

313-
@trace(name="response.created")
314-
async def _response_created(self, event: ResponseCreatedEvent):
309+
@trace
310+
async def response_created(self, event: ResponseCreatedEvent):
315311
pass
316312

317-
@trace(name="response.done")
318-
async def _response_done(self, event: ResponseDoneEvent):
313+
@trace
314+
async def response_done(self, event: ResponseDoneEvent):
319315
if event.response.output is not None and len(event.response.output) > 0:
320316
output = event.response.output[0]
321317
match output.type:
@@ -361,19 +357,24 @@ async def _response_done(self, event: ResponseDoneEvent):
361357

362358
self.active = False
363359

364-
@trace(name="response.output_item.added")
365-
async def _response_output_item_added(self, event: ResponseOutputItemAddedEvent):
360+
@trace
361+
async def response_output_item_added(self, event: ResponseOutputItemAddedEvent):
366362
pass
367363

368-
@trace(name="response.output_item.done")
369-
async def _response_output_item_done(self, event: ResponseOutputItemDoneEvent):
364+
@trace
365+
async def response_output_item_done(self, event: ResponseOutputItemDoneEvent):
370366
if event.item.type == "function_call":
367+
try:
368+
args = json.loads(event.item.arguments or "{}")
369+
except json.JSONDecodeError:
370+
args = {}
371+
371372
await self.connection.send_update(
372373
Update.function(
373374
id=str(event.item.id),
374375
call_id=str(event.item.call_id),
375376
name=str(event.item.name),
376-
arguments=json.loads(event.item.arguments or "{}"),
377+
arguments=args,
377378
)
378379
)
379380

@@ -389,55 +390,54 @@ async def _response_output_item_done(self, event: ResponseOutputItemDoneEvent):
389390
},
390391
)
391392

392-
@trace(name="response.content_part.added")
393-
async def _response_content_part_added(self, event: ResponseContentPartAddedEvent):
393+
@trace
394+
async def response_content_part_added(self, event: ResponseContentPartAddedEvent):
394395
pass
395396

396-
@trace(name="response.content_part.done")
397-
async def _response_content_part_done(self, event: ResponseContentPartDoneEvent):
397+
@trace
398+
async def response_content_part_done(self, event: ResponseContentPartDoneEvent):
398399
pass
399400

400-
@trace(name="response.text.delta")
401-
async def _response_text_delta(self, event: ResponseTextDeltaEvent):
401+
async def response_text_delta(self, event: ResponseTextDeltaEvent):
402402
pass
403403

404-
@trace(name="response.text.done")
405-
async def _response_text_done(self, event: ResponseTextDoneEvent):
404+
@trace
405+
async def response_text_done(self, event: ResponseTextDoneEvent):
406406
pass
407407

408-
async def _response_audio_transcript_delta(
408+
async def response_audio_transcript_delta(
409409
self, event: ResponseAudioTranscriptDeltaEvent
410410
):
411411
pass
412412

413-
@trace(name="response.audio.transcript.done")
414-
async def _response_audio_transcript_done(
413+
@trace
414+
async def response_audio_transcript_done(
415415
self, event: ResponseAudioTranscriptDoneEvent
416416
):
417417
pass
418418

419-
async def _response_audio_delta(self, event: ResponseAudioDeltaEvent):
419+
async def response_audio_delta(self, event: ResponseAudioDeltaEvent):
420420
await self.connection.send_update(
421421
Update.audio(id=event.event_id, data=event.delta)
422422
)
423423

424-
@trace(name="response.audio.done")
425-
async def _response_audio_done(self, event: ResponseAudioDoneEvent):
424+
@trace
425+
async def response_audio_done(self, event: ResponseAudioDoneEvent):
426426
pass
427427

428-
async def _response_function_call_arguments_delta(
428+
async def response_function_call_arguments_delta(
429429
self, event: ResponseFunctionCallArgumentsDeltaEvent
430430
):
431431
pass
432432

433-
@trace(name="response.function_call_arguments.done")
434-
async def _response_function_call_arguments_done(
433+
@trace
434+
async def response_function_call_arguments_done(
435435
self, event: ResponseFunctionCallArgumentsDoneEvent
436436
):
437437
pass
438438

439-
@trace(name="rate_limits.updated")
440-
async def _rate_limits_updated(self, event: RateLimitsUpdatedEvent):
439+
@trace
440+
async def rate_limits_updated(self, event: RateLimitsUpdatedEvent):
441441
pass
442442

443443
@trace

0 commit comments

Comments
 (0)