Skip to content

Commit 1848df2

Browse files
committed
feat: add special approval request otid for openai streaming (#5744)
* feat: add special approval request otid for openai streaming * fix import
1 parent c67bdd9 commit 1848df2

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

letta/interfaces/openai_streaming_interface.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from letta.schemas.message import Message
5353
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
5454
from letta.server.rest_api.json_parser import OptimisticJSONParser
55+
from letta.server.rest_api.utils import decrement_message_uuid
5556
from letta.streaming_utils import (
5657
FunctionArgumentsStreamHandler,
5758
JSONInnerThoughtsExtractor,
@@ -325,14 +326,14 @@ async def _process_chunk(
325326
self.tool_call_name = str(self.function_name_buffer)
326327
if self.tool_call_name in self.requires_approval_tools:
327328
tool_call_msg = ApprovalRequestMessage(
328-
id=self.letta_message_id,
329+
id=decrement_message_uuid(self.letta_message_id),
329330
date=datetime.now(timezone.utc),
330331
tool_call=ToolCallDelta(
331332
name=self.function_name_buffer,
332333
arguments=None,
333334
tool_call_id=self.function_id_buffer,
334335
),
335-
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
336+
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
336337
run_id=self.run_id,
337338
step_id=self.step_id,
338339
)
@@ -413,15 +414,15 @@ async def _process_chunk(
413414
message_index += 1
414415
if self.function_name_buffer in self.requires_approval_tools:
415416
tool_call_msg = ApprovalRequestMessage(
416-
id=self.letta_message_id,
417+
id=decrement_message_uuid(self.letta_message_id),
417418
date=datetime.now(timezone.utc),
418419
tool_call=ToolCallDelta(
419420
name=self.function_name_buffer,
420421
arguments=combined_chunk,
421422
tool_call_id=self.function_id_buffer,
422423
),
423424
# name=name,
424-
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
425+
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
425426
run_id=self.run_id,
426427
step_id=self.step_id,
427428
)
@@ -452,15 +453,15 @@ async def _process_chunk(
452453
message_index += 1
453454
if self.function_name_buffer in self.requires_approval_tools:
454455
tool_call_msg = ApprovalRequestMessage(
455-
id=self.letta_message_id,
456+
id=decrement_message_uuid(self.letta_message_id),
456457
date=datetime.now(timezone.utc),
457458
tool_call=ToolCallDelta(
458459
name=None,
459460
arguments=updates_main_json,
460461
tool_call_id=self.function_id_buffer,
461462
),
462463
# name=name,
463-
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
464+
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
464465
run_id=self.run_id,
465466
step_id=self.step_id,
466467
)
@@ -603,6 +604,8 @@ async def process(
603604
# For reasoning models, emit a hidden reasoning message before the first chunk
604605
if not self.emitted_hidden_reasoning and is_openai_reasoning_model(self.model):
605606
self.emitted_hidden_reasoning = True
607+
if prev_message_type and prev_message_type != "hidden_reasoning_message":
608+
message_index += 1
606609
hidden_message = HiddenReasoningMessage(
607610
id=self.letta_message_id,
608611
date=datetime.now(timezone.utc),
@@ -614,7 +617,6 @@ async def process(
614617
)
615618
self.content_messages.append(hidden_message)
616619
prev_message_type = hidden_message.message_type
617-
message_index += 1 # Increment for the next message
618620
yield hidden_message
619621

620622
async for chunk in stream:
@@ -676,6 +678,8 @@ async def _process_chunk(
676678
message_delta = choice.delta
677679

678680
if message_delta.content is not None and message_delta.content != "":
681+
if prev_message_type and prev_message_type != "assistant_message":
682+
message_index += 1
679683
assistant_msg = AssistantMessage(
680684
id=self.letta_message_id,
681685
content=message_delta.content,
@@ -686,7 +690,6 @@ async def _process_chunk(
686690
)
687691
self.content_messages.append(assistant_msg)
688692
prev_message_type = assistant_msg.message_type
689-
message_index += 1
690693
yield assistant_msg
691694

692695
if (
@@ -698,6 +701,8 @@ async def _process_chunk(
698701
delta = chunk.choices[0].delta
699702
reasoning_content = getattr(delta, "reasoning_content", None)
700703
if reasoning_content is not None and reasoning_content != "":
704+
if prev_message_type and prev_message_type != "reasoning_message":
705+
message_index += 1
701706
reasoning_msg = ReasoningMessage(
702707
id=self.letta_message_id,
703708
date=datetime.now(timezone.utc).isoformat(),
@@ -710,7 +715,6 @@ async def _process_chunk(
710715
)
711716
self.content_messages.append(reasoning_msg)
712717
prev_message_type = reasoning_msg.message_type
713-
message_index += 1
714718
yield reasoning_msg
715719

716720
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
@@ -746,19 +750,21 @@ async def _process_chunk(
746750

747751
if self.requires_approval_tools:
748752
tool_call_msg = ApprovalRequestMessage(
749-
id=self.letta_message_id,
753+
id=decrement_message_uuid(self.letta_message_id),
750754
date=datetime.now(timezone.utc),
751755
tool_call=ToolCallDelta(
752756
name=tool_call.function.name,
753757
arguments=tool_call.function.arguments,
754758
tool_call_id=tool_call.id,
755759
),
756760
# name=name,
757-
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
761+
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
758762
run_id=self.run_id,
759763
step_id=self.step_id,
760764
)
761765
else:
766+
if prev_message_type and prev_message_type != "tool_call_message":
767+
message_index += 1
762768
tool_call_delta = ToolCallDelta(
763769
name=tool_call.function.name,
764770
arguments=tool_call.function.arguments,
@@ -774,8 +780,7 @@ async def _process_chunk(
774780
run_id=self.run_id,
775781
step_id=self.step_id,
776782
)
777-
prev_message_type = tool_call_msg.message_type
778-
message_index += 1 # Increment for the next message
783+
prev_message_type = tool_call_msg.message_type
779784
yield tool_call_msg
780785

781786

@@ -971,11 +976,9 @@ async def _process_event(
971976
# cache for approval if/elses
972977
self.tool_call_name = name
973978
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
974-
if prev_message_type and prev_message_type != "approval_request_message":
975-
message_index += 1
976979
yield ApprovalRequestMessage(
977-
id=self.letta_message_id,
978-
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
980+
id=decrement_message_uuid(self.letta_message_id),
981+
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
979982
date=datetime.now(timezone.utc),
980983
tool_call=ToolCallDelta(
981984
name=name,
@@ -985,7 +988,6 @@ async def _process_event(
985988
run_id=self.run_id,
986989
step_id=self.step_id,
987990
)
988-
prev_message_type = "tool_call_message"
989991
else:
990992
if prev_message_type and prev_message_type != "tool_call_message":
991993
message_index += 1
@@ -1141,11 +1143,9 @@ async def _process_event(
11411143
delta = event.delta
11421144

11431145
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
1144-
if prev_message_type and prev_message_type != "approval_request_message":
1145-
message_index += 1
11461146
yield ApprovalRequestMessage(
1147-
id=self.letta_message_id,
1148-
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
1147+
id=decrement_message_uuid(self.letta_message_id),
1148+
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
11491149
date=datetime.now(timezone.utc),
11501150
tool_call=ToolCallDelta(
11511151
name=None,
@@ -1155,7 +1155,6 @@ async def _process_event(
11551155
run_id=self.run_id,
11561156
step_id=self.step_id,
11571157
)
1158-
prev_message_type = "approval_request_message"
11591158
else:
11601159
if prev_message_type and prev_message_type != "tool_call_message":
11611160
message_index += 1

0 commit comments

Comments
 (0)