16 changes: 11 additions & 5 deletions agent/usage_pricing.py
@@ -440,11 +440,16 @@ def normalize_usage(
provider_name = (provider or "").strip().lower()
mode = (api_mode or "").strip().lower()

reasoning_tokens = 0

if mode == "anthropic_messages" or provider_name == "anthropic":
input_tokens = _to_int(getattr(response_usage, "input_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0))
cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0))
output_details = getattr(response_usage, "output_tokens_details", None)
if output_details:
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
elif mode == "codex_responses":
input_total = _to_int(getattr(response_usage, "input_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "output_tokens", 0))
@@ -454,6 +459,9 @@ def normalize_usage(
getattr(details, "cache_creation_tokens", 0) if details else 0
)
input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens)
output_details = getattr(response_usage, "output_tokens_details", None)
if output_details:
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
else:
prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0))
output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0))
@@ -463,11 +471,9 @@ def normalize_usage(
getattr(details, "cache_write_tokens", 0) if details else 0
)
input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens)

reasoning_tokens = 0
output_details = getattr(response_usage, "output_tokens_details", None)
if output_details:
reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0))
completion_details = getattr(response_usage, "completion_tokens_details", None)
if completion_details:
reasoning_tokens = _to_int(getattr(completion_details, "reasoning_tokens", 0))

return CanonicalUsage(
input_tokens=input_tokens,
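The change hoists `reasoning_tokens = 0` above the provider branches so each branch can fill it from its own usage shape: Anthropic and Codex responses expose `output_tokens_details`, while the OpenAI-style fallback keeps reading `completion_tokens_details`. A minimal standalone sketch of that flow — the function name `extract_reasoning_tokens` and the `_to_int` stand-in are hypothetical, not the module's actual helpers:

```python
from types import SimpleNamespace

def _to_int(value) -> int:
    # Hypothetical stand-in for the module's integer-coercion helper.
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        return 0

def extract_reasoning_tokens(response_usage, api_mode: str) -> int:
    """Sketch: read reasoning tokens from whichever details object the mode exposes."""
    reasoning_tokens = 0
    if api_mode in ("anthropic_messages", "codex_responses"):
        details = getattr(response_usage, "output_tokens_details", None)
    else:
        details = getattr(response_usage, "completion_tokens_details", None)
    if details:
        reasoning_tokens = _to_int(getattr(details, "reasoning_tokens", 0))
    return reasoning_tokens

usage = SimpleNamespace(output_tokens_details=SimpleNamespace(reasoning_tokens=37))
assert extract_reasoning_tokens(usage, "codex_responses") == 37
```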
50 changes: 47 additions & 3 deletions gateway/run.py
@@ -3770,12 +3770,14 @@ async def _handle_message_with_agent(self, event, source, _quick_key: str):
skip_db=agent_persisted,
)

# Token counts and model are now persisted by the agent directly.
# Keep only last_prompt_tokens here for context-window tracking and
# compression decisions.
# Persist token counts to session store for /status display
self.session_store.update_session(
session_entry.session_key,
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
input_tokens=agent_result.get("input_tokens", 0),
output_tokens=agent_result.get("output_tokens", 0),
total_tokens=(agent_result.get("input_tokens", 0) + agent_result.get("output_tokens", 0)),
compression_count=agent_result.get("compression_count", 0),
)

# Auto voice reply: send TTS audio before the text response
@@ -4117,6 +4119,47 @@ async def _handle_status_command(self, event: MessageEvent) -> str:
f"**Connected Platforms:** {', '.join(connected_platforms)}",
])

# Context window info
try:
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS, _strip_provider_prefix

model = _resolve_gateway_model()
if model:
bare = _strip_provider_prefix(model)
# Quick lookup: exact match, then substring (longest key first)
ctx_len = DEFAULT_CONTEXT_LENGTHS.get(bare)
if ctx_len is None:
ctx_len = DEFAULT_CONTEXT_LENGTHS.get(model)
if ctx_len is None:
# Substring match — find longest matching key
best_key = ""
for key in DEFAULT_CONTEXT_LENGTHS:
if key in bare or bare.startswith(key):
if len(key) > len(best_key):
best_key = key
if best_key:
ctx_len = DEFAULT_CONTEXT_LENGTHS[best_key]

if ctx_len:
# Current usage: prefer last_prompt_tokens from session
used = session_entry.last_prompt_tokens
pct = min(100, int(used / ctx_len * 100)) if ctx_len else 0

def _fmt(n: int) -> str:
if n >= 1_000_000:
return f"{n / 1_000_000:.1f}M"
if n >= 1_000:
return f"{n // 1_000}k"
return str(n)

comp = session_entry.compression_count
lines.append("")
lines.append(f"📚 Context: {_fmt(used)}/{_fmt(ctx_len)} ({pct}%) · 🧹 Compactions: {comp}")
else:
logger.debug("Status: no context_length found for model=%r bare=%r", model, bare)
except Exception:
logger.exception("Status: failed to compute context window info")

return "\n".join(lines)

async def _handle_stop_command(self, event: MessageEvent) -> str:
@@ -7843,6 +7886,7 @@ def run_sync():
chat_id=source.chat_id,
config=_consumer_cfg,
metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None,
reply_to=event_message_id,
)
if _want_stream_deltas:
_stream_delta_cb = _stream_consumer.on_delta
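The new `/status` block resolves the model's context length with a three-step fallback (exact bare-name match, full model name, then longest substring key) and renders token counts with a compact `_fmt` helper. A self-contained sketch of that lookup-and-format logic, assuming an invented `DEFAULT_CONTEXT_LENGTHS` table and model names — only the fallback order and the rounding mirror the diff:

```python
# Hypothetical context-length table; the real one lives in agent.model_metadata.
DEFAULT_CONTEXT_LENGTHS = {"gpt-4o": 128_000, "claude-3-5-sonnet": 200_000}

def lookup_context_length(model: str, bare: str):
    """Exact bare-name match, then full-name match, then longest substring key."""
    ctx_len = DEFAULT_CONTEXT_LENGTHS.get(bare)
    if ctx_len is None:
        ctx_len = DEFAULT_CONTEXT_LENGTHS.get(model)
    if ctx_len is None:
        best_key = max(
            (key for key in DEFAULT_CONTEXT_LENGTHS if key in bare),
            key=len,
            default="",
        )
        if best_key:
            ctx_len = DEFAULT_CONTEXT_LENGTHS[best_key]
    return ctx_len

def _fmt(n: int) -> str:
    # Same rounding as the diff: 1.2M, 45k, or the plain number.
    if n >= 1_000_000:
        return f"{n / 1_000_000:.1f}M"
    if n >= 1_000:
        return f"{n // 1_000}k"
    return str(n)

ctx = lookup_context_length("openai/gpt-4o-mini", "gpt-4o-mini")
used = 45_000
pct = min(100, used * 100 // ctx)
print(f"📚 Context: {_fmt(used)}/{_fmt(ctx)} ({pct}%)")  # -> 📚 Context: 45k/128k (35%)
```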
17 changes: 17 additions & 0 deletions gateway/session.py
@@ -372,6 +372,9 @@ class SessionEntry:
# this session (create a new session_id) so the user starts fresh.
# Set by /stop to break stuck-resume loops (#7536).
suspended: bool = False

# Total compression/compaction count across all agent runs in this session.
compression_count: int = 0

def to_dict(self) -> Dict[str, Any]:
result = {
@@ -392,6 +395,7 @@ def to_dict(self) -> Dict[str, Any]:
"cost_status": self.cost_status,
"memory_flushed": self.memory_flushed,
"suspended": self.suspended,
"compression_count": self.compression_count,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@@ -429,6 +433,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
cost_status=data.get("cost_status", "unknown"),
memory_flushed=data.get("memory_flushed", False),
suspended=data.get("suspended", False),
compression_count=data.get("compression_count", 0),
)


@@ -770,6 +775,10 @@ def update_session(
self,
session_key: str,
last_prompt_tokens: int = None,
input_tokens: int = None,
output_tokens: int = None,
total_tokens: int = None,
compression_count: int = None,
) -> None:
"""Update lightweight session metadata after an interaction."""
with self._lock:
@@ -780,6 +789,14 @@
entry.updated_at = _now()
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
if input_tokens is not None:
entry.input_tokens += input_tokens
if output_tokens is not None:
entry.output_tokens += output_tokens
if total_tokens is not None:
entry.total_tokens += total_tokens
if compression_count is not None:
entry.compression_count += compression_count
self._save()

def suspend_session(self, session_key: str) -> bool:
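Note the mixed semantics in `update_session`: `last_prompt_tokens` is overwritten with the latest snapshot, while the new counters (`input_tokens`, `output_tokens`, `total_tokens`, `compression_count`) accumulate with `+=` across calls. A minimal sketch of that difference, using a stripped-down stand-in rather than the real `SessionEntry`:

```python
from dataclasses import dataclass

@dataclass
class _Entry:
    last_prompt_tokens: int = 0
    total_tokens: int = 0
    compression_count: int = 0

def update(entry: _Entry, *, last_prompt_tokens=None, total_tokens=None, compression_count=None):
    # Mirrors the diff: the snapshot field is assigned, cumulative fields are added.
    if last_prompt_tokens is not None:
        entry.last_prompt_tokens = last_prompt_tokens
    if total_tokens is not None:
        entry.total_tokens += total_tokens
    if compression_count is not None:
        entry.compression_count += compression_count

e = _Entry()
update(e, last_prompt_tokens=80, total_tokens=165, compression_count=0)
update(e, last_prompt_tokens=95, total_tokens=210, compression_count=1)
assert (e.last_prompt_tokens, e.total_tokens, e.compression_count) == (95, 375, 1)
```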
5 changes: 4 additions & 1 deletion gateway/stream_consumer.py
@@ -82,11 +82,13 @@ def __init__(
chat_id: str,
config: Optional[StreamConsumerConfig] = None,
metadata: Optional[dict] = None,
reply_to: Optional[str] = None,
):
self.adapter = adapter
self.chat_id = chat_id
self.cfg = config or StreamConsumerConfig()
self.metadata = metadata
self.reply_to = reply_to
self._queue: queue.Queue = queue.Queue()
self._accumulated = ""
self._message_id: Optional[str] = None
@@ -714,10 +716,11 @@ async def _send_or_edit(self, text: str) -> bool:
# The final response will be sent by the fallback path.
return False
else:
# First message — send new
# First message — send new (with reply_to to quote the user's message)
result = await self.adapter.send(
chat_id=self.chat_id,
content=text,
reply_to=self.reply_to,
metadata=self.metadata,
)
if result.success:
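The intended flow for the new `reply_to` parameter: the gateway hands the triggering message id to the consumer, only the first `adapter.send(...)` carries it (quoting the user's message), and later chunks edit the already-created message. A hedged sketch with a stub adapter and a trimmed-down consumer invented for the example:

```python
import asyncio

class FakeAdapter:
    # Stub adapter for illustration only; not the real platform adapter API.
    async def send(self, chat_id, content, reply_to=None, metadata=None):
        print(f"send to {chat_id!r}: {content!r} (reply_to={reply_to!r})")
        return type("R", (), {"success": True, "message_id": "m1"})()

    async def edit(self, chat_id, message_id, content):
        print(f"edit {message_id!r}: {content!r}")

class MiniConsumer:
    def __init__(self, adapter, chat_id, reply_to=None):
        self.adapter = adapter
        self.chat_id = chat_id
        self.reply_to = reply_to
        self._message_id = None

    async def push(self, text):
        if self._message_id is None:
            # First chunk: new message, quoted against the user's message.
            result = await self.adapter.send(
                chat_id=self.chat_id, content=text, reply_to=self.reply_to
            )
            if result.success:
                self._message_id = result.message_id
        else:
            # Subsequent chunks: edit the existing message in place.
            await self.adapter.edit(self.chat_id, self._message_id, text)

async def demo():
    consumer = MiniConsumer(FakeAdapter(), chat_id="42", reply_to="user-msg-7")
    await consumer.push("Thinking…")
    await consumer.push("Thinking… done.")

asyncio.run(demo())
```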
1 change: 1 addition & 0 deletions run_agent.py
@@ -10567,6 +10567,7 @@ def _stop_spinner():
"completion_tokens": self.session_completion_tokens,
"total_tokens": self.session_total_tokens,
"last_prompt_tokens": getattr(self.context_compressor, "last_prompt_tokens", 0) or 0,
"compression_count": getattr(self.context_compressor, "compression_count", 0) or 0,
"estimated_cost_usd": self.session_estimated_cost_usd,
"cost_status": self.session_cost_status,
"cost_source": self.session_cost_source,
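The `getattr(..., 0) or 0` pattern used for the new result key guards against both a missing attribute and an attribute that exists but is `None`. A quick illustration with a stand-in for the compressor object:

```python
from types import SimpleNamespace

compressor = SimpleNamespace(compression_count=None)  # attribute present but unset
assert (getattr(compressor, "compression_count", 0) or 0) == 0

compressor.compression_count = 3
assert (getattr(compressor, "compression_count", 0) or 0) == 3
```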
4 changes: 4 additions & 0 deletions tests/gateway/test_status_command.py
@@ -150,6 +150,10 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
runner.session_store.update_session.assert_called_once_with(
session_entry.session_key,
last_prompt_tokens=80,
input_tokens=120,
output_tokens=45,
total_tokens=165,
compression_count=0,
)

