@@ -137,6 +137,10 @@ def __init__(
             message_buffer_limit=message_buffer_limit,
             message_buffer_min=message_buffer_min,
             partial_evict_summarizer_percentage=partial_evict_summarizer_percentage,
+            agent_manager=self.agent_manager,
+            message_manager=self.message_manager,
+            actor=self.actor,
+            agent_id=self.agent_id,
         )

     async def _check_run_cancellation(self) -> bool:
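Note on the new constructor arguments above: the summarizer now appears to receive the agent's managers, actor, and agent id directly rather than only buffer tuning values. A minimal sketch of a constructor shaped like that call site, purely illustrative, with types, defaults, and any other parameters assumed rather than taken from the codebase:

```python
# Hypothetical sketch only: field names mirror the keyword arguments in the
# diff above; types, defaults, and other parameters are assumptions.
class Summarizer:
    def __init__(
        self,
        message_buffer_limit: int,
        message_buffer_min: int,
        partial_evict_summarizer_percentage: float,
        agent_manager=None,
        message_manager=None,
        actor=None,
        agent_id=None,
    ):
        self.message_buffer_limit = message_buffer_limit
        self.message_buffer_min = message_buffer_min
        self.partial_evict_summarizer_percentage = partial_evict_summarizer_percentage
        # Dependencies handed in by the agent so summarization can read and
        # write messages without reaching back into the caller.
        self.agent_manager = agent_manager
        self.message_manager = message_manager
        self.actor = actor
        self.agent_id = agent_id
```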
@@ -345,16 +349,17 @@ async def step_stream_no_tokens(
                 agent_step_span.end()

             # Log LLM Trace
-            await self.telemetry_manager.create_provider_trace_async(
-                actor=self.actor,
-                provider_trace_create=ProviderTraceCreate(
-                    request_json=request_data,
-                    response_json=response_data,
-                    step_id=step_id,  # Use original step_id for telemetry
-                    organization_id=self.actor.organization_id,
-                ),
-            )
-            step_progression = StepProgression.LOGGED_TRACE
+            if settings.track_provider_trace:
+                await self.telemetry_manager.create_provider_trace_async(
+                    actor=self.actor,
+                    provider_trace_create=ProviderTraceCreate(
+                        request_json=request_data,
+                        response_json=response_data,
+                        step_id=step_id,  # Use original step_id for telemetry
+                        organization_id=self.actor.organization_id,
+                    ),
+                )
+            step_progression = StepProgression.LOGGED_TRACE

             # stream step
             # TODO: improve TTFT
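The same `settings.track_provider_trace` guard is introduced in each of the step paths in this diff, so provider traces are only persisted when the flag is enabled. A rough sketch of how such a flag might be exposed on a settings object, assuming a pydantic-settings style configuration (the default value and env var wiring shown here are assumptions; only the field name comes from the diff):

```python
# Hypothetical sketch of the flag consulted by the guards in this diff.
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    # When False, skip writing raw LLM request/response traces via the
    # telemetry manager, saving one async write per agent step.
    track_provider_trace: bool = True

settings = Settings()  # e.g. TRACK_PROVIDER_TRACE=false would disable tracing
```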
@@ -642,17 +647,18 @@ async def _step(
                 agent_step_span.end()

             # Log LLM Trace
-            await self.telemetry_manager.create_provider_trace_async(
-                actor=self.actor,
-                provider_trace_create=ProviderTraceCreate(
-                    request_json=request_data,
-                    response_json=response_data,
-                    step_id=step_id,  # Use original step_id for telemetry
-                    organization_id=self.actor.organization_id,
-                ),
-            )
+            if settings.track_provider_trace:
+                await self.telemetry_manager.create_provider_trace_async(
+                    actor=self.actor,
+                    provider_trace_create=ProviderTraceCreate(
+                        request_json=request_data,
+                        response_json=response_data,
+                        step_id=step_id,  # Use original step_id for telemetry
+                        organization_id=self.actor.organization_id,
+                    ),
+                )
+            step_progression = StepProgression.LOGGED_TRACE

-            step_progression = StepProgression.LOGGED_TRACE
             MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes())
             step_progression = StepProgression.FINISHED

@@ -1003,31 +1009,32 @@ async def step_stream(
             # Log LLM Trace
             # We are piecing together the streamed response here.
             # Content here does not match the actual response schema as streams come in chunks.
-            await self.telemetry_manager.create_provider_trace_async(
-                actor=self.actor,
-                provider_trace_create=ProviderTraceCreate(
-                    request_json=request_data,
-                    response_json={
-                        "content": {
-                            "tool_call": tool_call.model_dump_json(),
-                            "reasoning": [content.model_dump_json() for content in reasoning_content],
+            if settings.track_provider_trace:
+                await self.telemetry_manager.create_provider_trace_async(
+                    actor=self.actor,
+                    provider_trace_create=ProviderTraceCreate(
+                        request_json=request_data,
+                        response_json={
+                            "content": {
+                                "tool_call": tool_call.model_dump_json(),
+                                "reasoning": [content.model_dump_json() for content in reasoning_content],
+                            },
+                            "id": interface.message_id,
+                            "model": interface.model,
+                            "role": "assistant",
+                            # "stop_reason": "",
+                            # "stop_sequence": None,
+                            "type": "message",
+                            "usage": {
+                                "input_tokens": usage.prompt_tokens,
+                                "output_tokens": usage.completion_tokens,
+                            },
                         },
-                        "id": interface.message_id,
-                        "model": interface.model,
-                        "role": "assistant",
-                        # "stop_reason": "",
-                        # "stop_sequence": None,
-                        "type": "message",
-                        "usage": {
-                            "input_tokens": usage.prompt_tokens,
-                            "output_tokens": usage.completion_tokens,
-                        },
-                    },
-                    step_id=step_id,  # Use original step_id for telemetry
-                    organization_id=self.actor.organization_id,
-                ),
-            )
-            step_progression = StepProgression.LOGGED_TRACE
+                        step_id=step_id,  # Use original step_id for telemetry
+                        organization_id=self.actor.organization_id,
+                    ),
+                )
+            step_progression = StepProgression.LOGGED_TRACE

             # yields tool response as this is handled from Letta and not the response from the LLM provider
             tool_return = [msg for msg in persisted_messages if msg.role == "tool"][-1].to_letta_messages()[0]
@@ -1352,6 +1359,7 @@ async def _rebuild_context_window(
     ) -> list[Message]:
         # If total tokens is reached, we truncate down
        # TODO: This can be broken by bad configs, e.g. lower bound too high, initial messages too fat, etc.
+        # TODO: `force` and `clear` seem to no longer be used, we should remove
         if force or (total_tokens and total_tokens > llm_config.context_window):
             self.logger.warning(
                 f"Total tokens {total_tokens} exceeds configured max tokens {llm_config.context_window}, forcefully clearing message history."
@@ -1363,6 +1371,7 @@ async def _rebuild_context_window(
                 clear=True,
             )
         else:
+            # NOTE (Sarah): Seems like this is doing nothing?
             self.logger.info(
                 f"Total tokens {total_tokens} does not exceed configured max tokens {llm_config.context_window}, passing summarizing w/o force."
             )
@@ -1453,8 +1462,10 @@ async def _create_llm_request_data_async(
             force_tool_call = valid_tool_names[0]

         allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
+        # Extract terminal tool names from tool rules
+        terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
         allowed_tools = runtime_override_tool_json_schema(
-            tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
+            tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True, terminal_tools=terminal_tool_names
         )

         return (
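On the terminal-tools change above: the new `terminal_tools` argument presumably lets the schema override skip the `request_heartbeat` injection for tools whose tool rules mark them as terminal, since a terminal tool ends the loop and a follow-up heartbeat would be meaningless. A sketch of that idea, with the override behavior assumed rather than copied from the actual implementation:

```python
# Hypothetical sketch of the intended behavior; not the library's actual code.
def runtime_override_tool_json_schema(tool_list, response_format=None, request_heartbeat=True, terminal_tools=frozenset()):
    for schema in tool_list:
        if request_heartbeat and schema["name"] not in terminal_tools:
            # Non-terminal tools get an extra boolean parameter the model can set
            # to request another agent step after the tool result is returned.
            schema["parameters"]["properties"]["request_heartbeat"] = {
                "type": "boolean",
                "description": "Request an immediate follow-up step after this tool runs.",
            }
        # Terminal tools are left untouched: they always end the step, so a
        # heartbeat request would contradict their terminal tool rule.
    return tool_list
```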