1010
1111from letta .agents .base_agent import BaseAgent
1212from letta .agents .ephemeral_summary_agent import EphemeralSummaryAgent
13- from letta .agents .helpers import _create_letta_response , _prepare_in_context_messages_no_persist_async , generate_step_id
13+ from letta .agents .helpers import (
14+ _build_rule_violation_result ,
15+ _create_letta_response ,
16+ _pop_heartbeat ,
17+ _prepare_in_context_messages_no_persist_async ,
18+ _safe_load_dict ,
19+ generate_step_id ,
20+ )
1421from letta .constants import DEFAULT_MAX_STEPS , NON_USER_MSG_PREFIX
1522from letta .errors import ContextWindowExceededError
1623from letta .helpers import ToolRulesSolver
@@ -931,45 +938,15 @@ async def _handle_ai_response(
931938 run_id : Optional [str ] = None ,
932939 ) -> Tuple [List [Message ], bool , Optional [LettaStopReason ]]:
933940 """
934- Now that streaming is done, handle the final AI response.
935- This might yield additional SSE tokens if we do stalling.
936- At the end, set self._continue_execution accordingly.
941+ Handle the final AI response once streaming completes, execute / validate the
942+ tool call, decide whether we should keep stepping, and persist state.
937943 """
938- stop_reason = None
939- # Check if the called tool is allowed by tool name:
940- tool_call_name = tool_call .function .name
941- tool_call_args_str = tool_call .function .arguments
942-
943- # Temp hack to gracefully handle parallel tool calling attempt, only take first one
944- if "}{" in tool_call_args_str :
945- tool_call_args_str = tool_call_args_str .split ("}{" , 1 )[0 ] + "}"
946-
947- try :
948- tool_args = json .loads (tool_call_args_str )
949- assert isinstance (tool_args , dict ), "tool_args must be a dict"
950- except json .JSONDecodeError :
951- tool_args = {}
952- except AssertionError :
953- tool_args = json .loads (tool_args )
954-
955- # Get request heartbeats and coerce to bool
956- request_heartbeat = tool_args .pop ("request_heartbeat" , False )
957- if is_final_step :
958- stop_reason = LettaStopReason (stop_reason = StopReasonType .max_steps .value )
959- self .logger .info ("Agent has reached max steps." )
960- request_heartbeat = False
961- else :
962- # Pre-emptively pop out inner_thoughts
963- tool_args .pop (INNER_THOUGHTS_KWARG , "" )
964-
965- # So this is necessary, because sometimes non-structured outputs makes mistakes
966- if not isinstance (request_heartbeat , bool ):
967- if isinstance (request_heartbeat , str ):
968- request_heartbeat = request_heartbeat .lower () == "true"
969- else :
970- request_heartbeat = bool (request_heartbeat )
971-
972- tool_call_id = tool_call .id or f"call_{ uuid .uuid4 ().hex [:8 ]} "
944+ # 1. Parse and validate the tool-call envelope
945+ tool_call_name : str = tool_call .function .name
946+ tool_call_id : str = tool_call .id or f"call_{ uuid .uuid4 ().hex [:8 ]} "
947+ tool_args = _safe_load_dict (tool_call .function .arguments )
948+ request_heartbeat : bool = _pop_heartbeat (tool_args )
949+ tool_args .pop (INNER_THOUGHTS_KWARG , None )
973950
974951 log_telemetry (
975952 self .logger ,
@@ -979,16 +956,11 @@ async def _handle_ai_response(
979956 tool_call_id = tool_call_id ,
980957 request_heartbeat = request_heartbeat ,
981958 )
982- # Check if tool rule is violated - if so, we'll force continuation
983- tool_rule_violated = tool_call_name not in valid_tool_names
984959
960+ # 2. Execute the tool (or synthesize an error result if disallowed)
961+ tool_rule_violated = tool_call_name not in valid_tool_names
985962 if tool_rule_violated :
986- base_error_message = f"[ToolConstraintError] Cannot call { tool_call_name } , valid tools to call include: { valid_tool_names } ."
987- violated_rule_messages = tool_rules_solver .guess_rule_violation (tool_call_name )
988- if violated_rule_messages :
989- bullet_points = "\n " .join (f"\t - { msg } " for msg in violated_rule_messages )
990- base_error_message += f"\n ** Hint: Possible rules that were violated:\n { bullet_points } "
991- tool_execution_result = ToolExecutionResult (status = "error" , func_return = base_error_message )
963+ tool_execution_result = _build_rule_violation_result (tool_call_name , valid_tool_names , tool_rules_solver )
992964 else :
993965 tool_execution_result = await self ._execute_tool (
994966 tool_name = tool_call_name ,
@@ -997,66 +969,38 @@ async def _handle_ai_response(
997969 agent_step_span = agent_step_span ,
998970 step_id = step_id ,
999971 )
972+
1000973 log_telemetry (
1001974 self .logger , "_handle_ai_response execute tool finish" , tool_execution_result = tool_execution_result , tool_call_id = tool_call_id
1002975 )
1003976
1004- if tool_call_name in ["conversation_search" , "conversation_search_date" , "archival_memory_search" ]:
1005- # with certain functions we rely on the paging mechanism to handle overflow
1006- truncate = False
1007- else :
1008- # but by default, we add a truncation safeguard to prevent bad functions from
1009- # overflow the agent context window
1010- truncate = True
1011-
1012- # get the function response limit
1013- target_tool = next ((x for x in agent_state .tools if x .name == tool_call_name ), None )
1014- return_char_limit = target_tool .return_char_limit if target_tool else None
977+ # 3. Prepare the function-response payload
978+ truncate = tool_call_name not in {"conversation_search" , "conversation_search_date" , "archival_memory_search" }
979+ return_char_limit = next (
980+ (t .return_char_limit for t in agent_state .tools if t .name == tool_call_name ),
981+ None ,
982+ )
1015983 function_response_string = validate_function_response (
1016- tool_execution_result .func_return , return_char_limit = return_char_limit , truncate = truncate
984+ tool_execution_result .func_return ,
985+ return_char_limit = return_char_limit ,
986+ truncate = truncate ,
1017987 )
1018- function_response = package_function_response (
988+ self . last_function_response = package_function_response (
1019989 was_success = tool_execution_result .success_flag ,
1020990 response_string = function_response_string ,
1021991 timezone = agent_state .timezone ,
1022992 )
1023993
1024- # 4. Register tool call with tool rule solver
1025- # Resolve whether or not to continue stepping
1026- continue_stepping = request_heartbeat
1027-
1028- # Force continuation if tool rule was violated to give the model another chance
1029- if tool_rule_violated :
1030- continue_stepping = True
1031- else :
1032- tool_rules_solver .register_tool_call (tool_name = tool_call_name )
1033- if tool_rules_solver .is_terminal_tool (tool_name = tool_call_name ):
1034- if continue_stepping :
1035- stop_reason = LettaStopReason (stop_reason = StopReasonType .tool_rule .value )
1036- continue_stepping = False
1037- elif tool_rules_solver .has_children_tools (tool_name = tool_call_name ):
1038- continue_stepping = True
1039- elif tool_rules_solver .is_continue_tool (tool_name = tool_call_name ):
1040- continue_stepping = True
1041-
1042- # Check if required-before-exit tools have been called before allowing exit
1043- heartbeat_reason = None # Default
1044- uncalled_required_tools = tool_rules_solver .get_uncalled_required_tools ()
1045- if not continue_stepping and uncalled_required_tools :
1046- continue_stepping = True
1047- heartbeat_reason = (
1048- f"{ NON_USER_MSG_PREFIX } Cannot finish, still need to call the following required tools: { ', ' .join (uncalled_required_tools )} "
1049- )
1050-
1051- # TODO: @caren is this right?
1052- # reset stop reason since we ain't stopping!
1053- stop_reason = None
1054- self .logger .info (f"RequiredBeforeExitToolRule: Forcing agent continuation. Missing required tools: { uncalled_required_tools } " )
994+ # 4. Decide whether to keep stepping (<<< focal section simplified)
995+ continue_stepping , heartbeat_reason , stop_reason = self ._decide_continuation (
996+ request_heartbeat = request_heartbeat ,
997+ tool_call_name = tool_call_name ,
998+ tool_rule_violated = tool_rule_violated ,
999+ tool_rules_solver = tool_rules_solver ,
1000+ is_final_step = is_final_step ,
1001+ )
10551002
1056- # 5a. Persist Steps to DB
1057- # Following agent loop to persist this before messages
1058- # TODO (cliandy): determine what should match old loop w/provider_id
1059- # TODO (cliandy): UsageStatistics and LettaUsageStatistics are used in many places, but are not the same.
1003+ # 5. Persist step + messages and propagate to jobs
10601004 logged_step = await self .step_manager .log_step_async (
10611005 actor = self .actor ,
10621006 agent_id = agent_state .id ,
@@ -1071,7 +1015,6 @@ async def _handle_ai_response(
10711015 step_id = step_id ,
10721016 )
10731017
1074- # 5b. Persist Messages to DB
10751018 tool_call_messages = create_letta_messages_from_llm_response (
10761019 agent_id = agent_state .id ,
10771020 model = agent_state .llm_config .model ,
@@ -1083,27 +1026,72 @@ async def _handle_ai_response(
10831026 function_response = function_response_string ,
10841027 timezone = agent_state .timezone ,
10851028 actor = self .actor ,
1086- add_heartbeat_request_system_message = continue_stepping ,
1029+ continue_stepping = continue_stepping ,
10871030 heartbeat_reason = heartbeat_reason ,
10881031 reasoning_content = reasoning_content ,
10891032 pre_computed_assistant_message_id = pre_computed_assistant_message_id ,
1090- step_id = logged_step .id if logged_step else None , # TODO (cliandy): eventually move over other agent loops
1033+ step_id = logged_step .id if logged_step else None ,
10911034 )
10921035
10931036 persisted_messages = await self .message_manager .create_many_messages_async (
10941037 (initial_messages or []) + tool_call_messages , actor = self .actor
10951038 )
1096- self .last_function_response = function_response
10971039
10981040 if run_id :
10991041 await self .job_manager .add_messages_to_job_async (
11001042 job_id = run_id ,
1101- message_ids = [message .id for message in persisted_messages if message .role != "user" ],
1043+ message_ids = [m .id for m in persisted_messages if m .role != "user" ],
11021044 actor = self .actor ,
11031045 )
11041046
11051047 return persisted_messages , continue_stepping , stop_reason
11061048
def _decide_continuation(
    self,
    request_heartbeat: bool,
    tool_call_name: str,
    tool_rule_violated: bool,
    tool_rules_solver: ToolRulesSolver,
    is_final_step: bool | None,
) -> tuple[bool, str | None, LettaStopReason | None]:
    """
    Decide whether the agent loop should take another step after a tool call.

    The decision starts from the model's own ``request_heartbeat`` flag and is
    then adjusted, in order, by:

    1. Tool-rule violation: always continue, with an explanatory heartbeat
       reason. The call is *not* registered with the solver.
    2. Tool rules for a valid call (after registering it with the solver):
       a terminal tool stops the loop (a ``tool_rule`` stop reason is recorded
       only when a heartbeat had been requested); a tool with children or a
       continue tool forces continuation with a matching heartbeat reason.
    3. Final step: always stop with a ``max_steps`` stop reason. The
       required-tools check is not consulted, and whatever heartbeat reason
       was computed above is returned unchanged.
    4. Required tools (non-final steps only): if the loop would otherwise stop
       while the solver still reports uncalled required tools, continue
       instead, list them in the heartbeat reason, and clear the stop reason.

    Args:
        request_heartbeat: Whether the model asked to keep going.
        tool_call_name: Name of the tool the model called.
        tool_rule_violated: True when the tool was not among the valid tools.
        tool_rules_solver: Solver that tracks calls and answers rule queries.
        is_final_step: Truthy when the step budget has been exhausted.

    Returns:
        ``(continue_stepping, heartbeat_reason, stop_reason)``.
    """
    keep_going = request_heartbeat
    reason: str | None = None
    halt: LettaStopReason | None = None

    if not tool_rule_violated:
        # Only valid calls are recorded with the solver.
        tool_rules_solver.register_tool_call(tool_call_name)

        if tool_rules_solver.is_terminal_tool(tool_call_name):
            # A terminal tool wins over the model's heartbeat request; note
            # the override as a stop reason only if there was one to override.
            if keep_going:
                halt = LettaStopReason(stop_reason=StopReasonType.tool_rule.value)
            keep_going = False
        elif tool_rules_solver.has_children_tools(tool_call_name):
            keep_going, reason = True, f"{NON_USER_MSG_PREFIX} Continuing: child tool rule."
        elif tool_rules_solver.is_continue_tool(tool_call_name):
            keep_going, reason = True, f"{NON_USER_MSG_PREFIX} Continuing: continue tool rule."
    else:
        # Give the model another chance after calling a disallowed tool.
        keep_going, reason = True, f"{NON_USER_MSG_PREFIX} Continuing: tool rule violation."

    # Hard stop: the step budget overrides every rule above.
    if is_final_step:
        return False, reason, LettaStopReason(stop_reason=StopReasonType.max_steps.value)

    missing = tool_rules_solver.get_uncalled_required_tools()
    if keep_going or not missing:
        return keep_going, reason, halt

    # We were about to stop, but required tools are still outstanding:
    # keep going and drop any stop reason recorded earlier.
    return True, f"{NON_USER_MSG_PREFIX} Missing required tools: {', '.join(missing)} ", None
1094+
11071095 @trace_method
11081096 async def _execute_tool (
11091097 self ,
0 commit comments