@@ -1198,3 +1198,84 @@ async def test_event_loop_metrics_recorded_before_recursion(
11981198 # Verify the event loop completed successfully
11991199 tru_stop_reason , _ , _ , _ , _ , _ = events [- 1 ]["stop" ]
12001200 assert tru_stop_reason == "end_turn"
1201+
1202+
1203+ class TestEstimateInputTokens :
1204+ """Tests for _estimate_input_tokens helper."""
1205+
1206+ @pytest .mark .asyncio
1207+ async def test_cold_start_estimates_all_messages (self ):
1208+ """On cold start (no prior usage metadata), estimates all messages with lazily resolved tool specs."""
1209+ agent = unittest .mock .AsyncMock ()
1210+ agent .messages = [{"role" : "user" , "content" : [{"text" : "Hi" }]}]
1211+ agent .system_prompt = "You are helpful"
1212+ agent ._system_prompt_content = None
1213+ agent .tool_registry = unittest .mock .MagicMock ()
1214+ agent .tool_registry .get_all_tool_specs .return_value = [{"name" : "tool1" }]
1215+ agent .model .count_tokens = AsyncMock (return_value = 42 )
1216+
1217+ result = await strands .event_loop .event_loop ._estimate_input_tokens (agent )
1218+
1219+ assert result == 42
1220+ agent .tool_registry .get_all_tool_specs .assert_called_once ()
1221+ agent .model .count_tokens .assert_called_once_with (
1222+ agent .messages ,
1223+ tool_specs = [{"name" : "tool1" }],
1224+ system_prompt = "You are helpful" ,
1225+ system_prompt_content = None ,
1226+ )
1227+
1228+ @pytest .mark .asyncio
1229+ async def test_baseline_only_no_new_messages (self ):
1230+ """When last message is assistant with usage and no new messages after, returns baseline."""
1231+ agent = unittest .mock .AsyncMock ()
1232+ agent .messages = [
1233+ {"role" : "user" , "content" : [{"text" : "Hi" }]},
1234+ {
1235+ "role" : "assistant" ,
1236+ "content" : [{"text" : "Hello" }],
1237+ "metadata" : {"usage" : {"inputTokens" : 100 , "outputTokens" : 20 , "totalTokens" : 120 }},
1238+ },
1239+ ]
1240+ agent .system_prompt = "You are helpful"
1241+
1242+ result = await strands .event_loop .event_loop ._estimate_input_tokens (agent )
1243+
1244+ assert result == 120
1245+ agent .model .count_tokens .assert_not_called ()
1246+
1247+ @pytest .mark .asyncio
1248+ async def test_baseline_plus_delta (self ):
1249+ """When new messages exist after last assistant, adds estimated delta to baseline."""
1250+ agent = unittest .mock .AsyncMock ()
1251+ agent .messages = [
1252+ {"role" : "user" , "content" : [{"text" : "Hi" }]},
1253+ {
1254+ "role" : "assistant" ,
1255+ "content" : [{"text" : "Hello" }],
1256+ "metadata" : {"usage" : {"inputTokens" : 100 , "outputTokens" : 30 , "totalTokens" : 130 }},
1257+ },
1258+ {"role" : "user" , "content" : [{"text" : "tool result" }]},
1259+ ]
1260+ agent .system_prompt = "You are helpful"
1261+ agent .model .count_tokens = AsyncMock (return_value = 50 )
1262+
1263+ result = await strands .event_loop .event_loop ._estimate_input_tokens (agent )
1264+
1265+ # baseline (100+30) + delta (50) = 180
1266+ assert result == 180
1267+ agent .model .count_tokens .assert_called_once ()
1268+
1269+ @pytest .mark .asyncio
1270+ async def test_error_fallback_returns_none_at_call_site (self ):
1271+ """When count_tokens raises, the caller catches and sets projected_input_tokens to None."""
1272+ agent = unittest .mock .AsyncMock ()
1273+ agent .messages = [{"role" : "user" , "content" : [{"text" : "Hi" }]}]
1274+ agent .system_prompt = "You are helpful"
1275+ agent ._system_prompt_content = None
1276+ agent .tool_registry = unittest .mock .MagicMock ()
1277+ agent .tool_registry .get_all_tool_specs .return_value = []
1278+ agent .model .count_tokens = AsyncMock (side_effect = Exception ("API unavailable" ))
1279+
1280+ with pytest .raises (Exception , match = "API unavailable" ):
1281+ await strands .event_loop .event_loop ._estimate_input_tokens (agent )
0 commit comments