Skip to content

Commit 888c98c

Browse files
authored
feat: estimate input tokens before model calls (#2221)
1 parent e12ac9d commit 888c98c

9 files changed

Lines changed: 242 additions & 4 deletions

File tree

src/strands/agent/agent_result.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ def context_size(self) -> int | None:
4444
"""
4545
return self.metrics.latest_context_size
4646

47+
@property
48+
def projected_context_size(self) -> int | None:
49+
"""Projected context size for the next model call.
50+
51+
Returns:
52+
The projected token count (inputTokens + outputTokens), or None if no data is available.
53+
"""
54+
return self.metrics.projected_context_size
55+
4756
def __str__(self) -> str:
4857
"""Return a string representation of the agent result.
4958

src/strands/event_loop/event_loop.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,48 @@ def _has_tool_use_in_latest_message(messages: "Messages") -> bool:
7575
return False
7676

7777

78+
async def _estimate_input_tokens(agent: "Agent") -> int:
    """Estimate the input token count for the next model call.

    Reads inputTokens + outputTokens from the last assistant message's usage metadata as a
    known baseline, then estimates only the messages added after it. Falls back to full
    estimation (all messages + tool specs + system prompt) when no usage metadata is
    available (cold start or first call). On cold start, tool specs are resolved lazily so
    that the caller does not need to resolve them before BeforeModelCallEvent.

    Args:
        agent: The agent instance with messages and model.

    Returns:
        Estimated input token count.
    """
    messages = agent.messages

    # Find the last assistant message that carries usage metadata.
    last_assistant_idx = -1
    for i, msg in reversed(list(enumerate(messages))):
        # "metadata" may be absent OR present with an explicit None value;
        # `or {}` guards both before reading "usage".
        if msg.get("role") == "assistant" and (msg.get("metadata") or {}).get("usage"):
            last_assistant_idx = i
            break

    if last_assistant_idx >= 0:
        usage = messages[last_assistant_idx]["metadata"]["usage"]
        known_baseline = usage["inputTokens"] + usage["outputTokens"]
        new_messages = messages[last_assistant_idx + 1 :]
        if not new_messages:
            return known_baseline
        # System prompt and tool spec tokens are already included in the baseline.
        return known_baseline + await agent.model.count_tokens(new_messages)

    # Cold start: resolve tool specs lazily for estimation only.
    tool_specs = agent.tool_registry.get_all_tool_specs()
    return await agent.model.count_tokens(
        messages,
        tool_specs=tool_specs,
        system_prompt=agent.system_prompt,
        system_prompt_content=agent._system_prompt_content,
    )
118+
119+
78120
async def event_loop_cycle(
79121
agent: "Agent",
80122
invocation_state: dict[str, Any],
@@ -325,10 +367,18 @@ async def _handle_model_execution(
325367
)
326368
with trace_api.use_span(model_invoke_span, end_on_exit=False):
327369
try:
370+
# Estimate input tokens for the upcoming model call (non-fatal)
371+
projected_input_tokens: int | None = None
372+
try:
373+
projected_input_tokens = await _estimate_input_tokens(agent)
374+
except Exception as e:
375+
logger.debug("error=<%s> | token estimation failed, proceeding without estimate", e)
376+
328377
await agent.hooks.invoke_callbacks_async(
329378
BeforeModelCallEvent(
330379
agent=agent,
331380
invocation_state=invocation_state,
381+
projected_input_tokens=projected_input_tokens,
332382
)
333383
)
334384

src/strands/hooks/events.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,14 @@ class BeforeModelCallEvent(HookEvent):
236236
invocation_state: State and configuration passed through the agent invocation.
237237
This can include shared context for multi-agent coordination, request tracking,
238238
and dynamic configuration.
239+
projected_input_tokens: Projected input token count for the upcoming model call.
240+
Computed by the agent loop from message metadata and token estimation.
241+
Available for hooks and plugins (e.g. conversation managers) to make
242+
proactive decisions about context management. None if estimation failed.
239243
"""
240244

241245
invocation_state: dict[str, Any] = field(default_factory=dict)
246+
projected_input_tokens: int | None = None
242247

243248

244249
@dataclass

src/strands/telemetry/metrics.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,25 @@ def latest_context_size(self) -> int | None:
215215
return self.agent_invocations[-1].cycles[-1].usage.get("inputTokens")
216216
return None
217217

218+
@property
219+
def projected_context_size(self) -> int | None:
220+
"""Projected context size for the next model call.
221+
222+
Computed as inputTokens + outputTokens from the most recent cycle's usage,
223+
representing the approximate input token count for the next model call
224+
(prior input + generated output that is now part of the conversation).
225+
226+
Returns:
227+
The projected token count, or None if no data is available.
228+
"""
229+
if self.agent_invocations and self.agent_invocations[-1].cycles:
230+
usage = self.agent_invocations[-1].cycles[-1].usage
231+
input_tokens = usage.get("inputTokens")
232+
output_tokens = usage.get("outputTokens")
233+
if input_tokens is not None and output_tokens is not None:
234+
return input_tokens + output_tokens
235+
return None
236+
218237
@property
219238
def _metrics_client(self) -> "MetricsClient":
220239
"""Get the singleton MetricsClient instance."""

tests/strands/agent/hooks/test_events.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,22 @@ def test_after_invocation_event_resume_accepts_various_input_types(agent):
260260
# None to stop
261261
event.resume = None
262262
assert event.resume is None
263+
264+
265+
def test_before_model_call_event_projected_input_tokens_default(agent):
    """projected_input_tokens should default to None when not supplied."""
    default_event = BeforeModelCallEvent(agent=agent)
    assert default_event.projected_input_tokens is None
269+
270+
271+
def test_before_model_call_event_projected_input_tokens_set(agent):
    """projected_input_tokens passed at construction should be stored as given."""
    constructed_event = BeforeModelCallEvent(agent=agent, projected_input_tokens=500)
    assert constructed_event.projected_input_tokens == 500
275+
276+
277+
def test_before_model_call_event_projected_input_tokens_not_writable(agent):
    """Assigning projected_input_tokens after construction must raise AttributeError."""
    readonly_event = BeforeModelCallEvent(agent=agent, projected_input_tokens=500)
    with pytest.raises(AttributeError, match="Property projected_input_tokens is not writable"):
        readonly_event.projected_input_tokens = 1000

tests/strands/agent/test_agent_hooks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
165165
agent=agent,
166166
message=agent.messages[0],
167167
)
168-
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
168+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY)
169169
assert next(events) == AfterModelCallEvent(
170170
agent=agent,
171171
invocation_state=ANY,
@@ -195,7 +195,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u
195195
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
196196
)
197197
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
198-
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
198+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY)
199199
assert next(events) == AfterModelCallEvent(
200200
agent=agent,
201201
invocation_state=ANY,
@@ -239,7 +239,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
239239
agent=agent,
240240
message=agent.messages[0],
241241
)
242-
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
242+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY)
243243
assert next(events) == AfterModelCallEvent(
244244
agent=agent,
245245
invocation_state=ANY,
@@ -269,7 +269,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
269269
result={"content": [{"text": "!loot a dekovni I"}], "status": "success", "toolUseId": "123"},
270270
)
271271
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[2])
272-
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY)
272+
assert next(events) == BeforeModelCallEvent(agent=agent, invocation_state=ANY, projected_input_tokens=ANY)
273273
assert next(events) == AfterModelCallEvent(
274274
agent=agent,
275275
invocation_state=ANY,

tests/strands/agent/test_agent_result.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,17 @@ def test_context_size_none_when_no_data(mock_metrics, simple_message: Message):
384384
mock_metrics.latest_context_size = None
385385
result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={})
386386
assert result.context_size is None
387+
388+
389+
def test_projected_context_size_delegates_to_metrics(mock_metrics, simple_message: Message):
    """projected_context_size should mirror metrics.projected_context_size."""
    mock_metrics.projected_context_size = 15000
    agent_result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={})
    assert agent_result.projected_context_size == 15000
394+
395+
396+
def test_projected_context_size_none_when_no_data(mock_metrics, simple_message: Message):
    """projected_context_size should be None when the metrics object has no data."""
    mock_metrics.projected_context_size = None
    agent_result = AgentResult(stop_reason="end_turn", message=simple_message, metrics=mock_metrics, state={})
    assert agent_result.projected_context_size is None

tests/strands/event_loop/test_event_loop.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,3 +1198,84 @@ async def test_event_loop_metrics_recorded_before_recursion(
11981198
# Verify the event loop completed successfully
11991199
tru_stop_reason, _, _, _, _, _ = events[-1]["stop"]
12001200
assert tru_stop_reason == "end_turn"
1201+
1202+
1203+
class TestEstimateInputTokens:
    """Tests for _estimate_input_tokens helper."""

    @pytest.mark.asyncio
    async def test_cold_start_estimates_all_messages(self):
        """On cold start (no prior usage metadata), estimates all messages with lazily resolved tool specs."""
        agent = unittest.mock.AsyncMock()
        agent.messages = [{"role": "user", "content": [{"text": "Hi"}]}]
        agent.system_prompt = "You are helpful"
        agent._system_prompt_content = None
        agent.tool_registry = unittest.mock.MagicMock()
        agent.tool_registry.get_all_tool_specs.return_value = [{"name": "tool1"}]
        agent.model.count_tokens = AsyncMock(return_value=42)

        result = await strands.event_loop.event_loop._estimate_input_tokens(agent)

        assert result == 42
        # Cold start must resolve tool specs itself; callers are not required to.
        agent.tool_registry.get_all_tool_specs.assert_called_once()
        agent.model.count_tokens.assert_called_once_with(
            agent.messages,
            tool_specs=[{"name": "tool1"}],
            system_prompt="You are helpful",
            system_prompt_content=None,
        )

    @pytest.mark.asyncio
    async def test_baseline_only_no_new_messages(self):
        """When last message is assistant with usage and no new messages after, returns baseline."""
        agent = unittest.mock.AsyncMock()
        agent.messages = [
            {"role": "user", "content": [{"text": "Hi"}]},
            {
                "role": "assistant",
                "content": [{"text": "Hello"}],
                "metadata": {"usage": {"inputTokens": 100, "outputTokens": 20, "totalTokens": 120}},
            },
        ]
        agent.system_prompt = "You are helpful"

        result = await strands.event_loop.event_loop._estimate_input_tokens(agent)

        # Baseline is inputTokens + outputTokens from the last assistant usage (100 + 20).
        assert result == 120
        # No new messages after the baseline, so no estimation call should happen.
        agent.model.count_tokens.assert_not_called()

    @pytest.mark.asyncio
    async def test_baseline_plus_delta(self):
        """When new messages exist after last assistant, adds estimated delta to baseline."""
        agent = unittest.mock.AsyncMock()
        agent.messages = [
            {"role": "user", "content": [{"text": "Hi"}]},
            {
                "role": "assistant",
                "content": [{"text": "Hello"}],
                "metadata": {"usage": {"inputTokens": 100, "outputTokens": 30, "totalTokens": 130}},
            },
            {"role": "user", "content": [{"text": "tool result"}]},
        ]
        agent.system_prompt = "You are helpful"
        agent.model.count_tokens = AsyncMock(return_value=50)

        result = await strands.event_loop.event_loop._estimate_input_tokens(agent)

        # baseline (100+30) + delta (50) = 180
        assert result == 180
        # Only the trailing new messages should be estimated, in a single call.
        agent.model.count_tokens.assert_called_once()

    @pytest.mark.asyncio
    async def test_error_fallback_returns_none_at_call_site(self):
        """When count_tokens raises, the caller catches and sets projected_input_tokens to None."""
        agent = unittest.mock.AsyncMock()
        agent.messages = [{"role": "user", "content": [{"text": "Hi"}]}]
        agent.system_prompt = "You are helpful"
        agent._system_prompt_content = None
        agent.tool_registry = unittest.mock.MagicMock()
        agent.tool_registry.get_all_tool_specs.return_value = []
        agent.model.count_tokens = AsyncMock(side_effect=Exception("API unavailable"))

        # The helper itself propagates the error; the event-loop call site is
        # responsible for catching it and proceeding without an estimate.
        with pytest.raises(Exception, match="API unavailable"):
            await strands.event_loop.event_loop._estimate_input_tokens(agent)

tests/strands/telemetry/test_metrics.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,44 @@ def test_latest_context_size_missing_input_tokens_key(event_loop_metrics):
613613
)
614614
)
615615
assert event_loop_metrics.latest_context_size is None
616+
617+
618+
def test_projected_context_size_no_invocations(event_loop_metrics):
    """Without any agent invocations there is nothing to project."""
    assert event_loop_metrics.projected_context_size is None
620+
621+
622+
def test_projected_context_size_invocation_with_no_cycles(event_loop_metrics):
    """An invocation that has recorded no cycles yields None."""
    metrics = event_loop_metrics
    metrics.reset_usage_metrics()
    assert metrics.projected_context_size is None
625+
626+
627+
def test_projected_context_size_returns_input_plus_output(event_loop_metrics, mock_get_meter_provider):
    """A cycle with full usage projects inputTokens + outputTokens."""
    metrics = event_loop_metrics
    metrics.reset_usage_metrics()
    metrics.start_cycle(attributes={"event_loop_cycle_id": "c1"})
    metrics.update_usage(Usage(inputTokens=100, outputTokens=50, totalTokens=150))
    assert metrics.projected_context_size == 150
633+
634+
635+
def test_projected_context_size_updates_across_cycles(event_loop_metrics, mock_get_meter_provider):
    """The projection always tracks the most recent cycle's usage."""
    metrics = event_loop_metrics
    metrics.reset_usage_metrics()
    # Record two cycles; the second one should win.
    for cycle_id, usage in (
        ("c1", Usage(inputTokens=100, outputTokens=50, totalTokens=150)),
        ("c2", Usage(inputTokens=200, outputTokens=80, totalTokens=280)),
    ):
        metrics.start_cycle(attributes={"event_loop_cycle_id": cycle_id})
        metrics.update_usage(usage)
    assert metrics.projected_context_size == 280
644+
645+
646+
def test_projected_context_size_missing_tokens_key(event_loop_metrics):
    """Returns None when usage dict is missing inputTokens or outputTokens."""
    event_loop_metrics.reset_usage_metrics()
    latest_invocation = event_loop_metrics.agent_invocations[-1]
    # Craft a cycle whose usage lacks "inputTokens" entirely.
    incomplete_cycle = strands.telemetry.metrics.EventLoopCycleMetric(
        event_loop_cycle_id="c1",
        usage={"outputTokens": 50, "totalTokens": 50},
    )
    latest_invocation.cycles.append(incomplete_cycle)
    assert event_loop_metrics.projected_context_size is None

0 commit comments

Comments
 (0)