Skip to content

Commit 20441cf

Browse files
kshitijk4poorteknium1
authored andcommitted
fix(insights): persist token usage for non-CLI sessions
1 parent 585855d commit 20441cf

6 files changed

Lines changed: 73 additions & 112 deletions

File tree

gateway/run.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2804,20 +2804,12 @@ async def _handle_message_with_agent(self, event, source, _quick_key: str):
28042804
skip_db=agent_persisted,
28052805
)
28062806

2807-
# Update session with actual prompt token count and model from the agent
2807+
# Token counts and model are now persisted by the agent directly.
2808+
# Keep only last_prompt_tokens here for context-window tracking and
2809+
# compression decisions.
28082810
self.session_store.update_session(
28092811
session_entry.session_key,
2810-
input_tokens=agent_result.get("input_tokens", 0),
2811-
output_tokens=agent_result.get("output_tokens", 0),
2812-
cache_read_tokens=agent_result.get("cache_read_tokens", 0),
2813-
cache_write_tokens=agent_result.get("cache_write_tokens", 0),
28142812
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
2815-
model=agent_result.get("model"),
2816-
estimated_cost_usd=agent_result.get("estimated_cost_usd"),
2817-
cost_status=agent_result.get("cost_status"),
2818-
cost_source=agent_result.get("cost_source"),
2819-
provider=agent_result.get("provider"),
2820-
base_url=agent_result.get("base_url"),
28212813
)
28222814

28232815
# Auto voice reply: send TTS audio before the text response

gateway/session.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -778,66 +778,18 @@ def get_or_create_session(
778778
def update_session(
779779
self,
780780
session_key: str,
781-
input_tokens: int = 0,
782-
output_tokens: int = 0,
783-
cache_read_tokens: int = 0,
784-
cache_write_tokens: int = 0,
785781
last_prompt_tokens: int = None,
786-
model: str = None,
787-
estimated_cost_usd: Optional[float] = None,
788-
cost_status: Optional[str] = None,
789-
cost_source: Optional[str] = None,
790-
provider: Optional[str] = None,
791-
base_url: Optional[str] = None,
792782
) -> None:
793-
"""Update a session's metadata after an interaction."""
794-
db_session_id = None
795-
783+
"""Update lightweight session metadata after an interaction."""
796784
with self._lock:
797785
self._ensure_loaded_locked()
798786

799787
if session_key in self._entries:
800788
entry = self._entries[session_key]
801789
entry.updated_at = _now()
802-
# Direct assignment — the gateway receives cumulative totals
803-
# from the cached agent, not per-call deltas.
804-
entry.input_tokens = input_tokens
805-
entry.output_tokens = output_tokens
806-
entry.cache_read_tokens = cache_read_tokens
807-
entry.cache_write_tokens = cache_write_tokens
808790
if last_prompt_tokens is not None:
809791
entry.last_prompt_tokens = last_prompt_tokens
810-
if estimated_cost_usd is not None:
811-
entry.estimated_cost_usd = estimated_cost_usd
812-
if cost_status:
813-
entry.cost_status = cost_status
814-
entry.total_tokens = (
815-
entry.input_tokens
816-
+ entry.output_tokens
817-
+ entry.cache_read_tokens
818-
+ entry.cache_write_tokens
819-
)
820792
self._save()
821-
db_session_id = entry.session_id
822-
823-
if self._db and db_session_id:
824-
try:
825-
self._db.set_token_counts(
826-
db_session_id,
827-
input_tokens=input_tokens,
828-
output_tokens=output_tokens,
829-
cache_read_tokens=cache_read_tokens,
830-
cache_write_tokens=cache_write_tokens,
831-
estimated_cost_usd=estimated_cost_usd,
832-
cost_status=cost_status,
833-
cost_source=cost_source,
834-
billing_provider=provider,
835-
billing_base_url=base_url,
836-
model=model,
837-
absolute=True,
838-
)
839-
except Exception as e:
840-
logger.debug("Session DB operation failed: %s", e)
841793

842794
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
843795
"""Force reset a session, creating a new session ID."""

run_agent.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7221,11 +7221,13 @@ def _stop_spinner():
72217221
self.session_cost_source = cost_result.source
72227222

72237223
# Persist token counts to session DB for /insights.
7224-
# Gateway sessions persist via session_store.update_session()
7225-
# after run_conversation returns, so only persist here for
7226-
# CLI (and other non-gateway) platforms to avoid double-counting.
7227-
if (self._session_db and self.session_id
7228-
and getattr(self, 'platform', None) == 'cli'):
7224+
# Do this for every platform with a session_id so non-CLI
7225+
# sessions (gateway, cron, delegated runs) cannot lose
7226+
# token/accounting data if a higher-level persistence path
7227+
# is skipped or fails. Gateway/session-store writes use
7228+
# absolute totals, so they safely overwrite these per-call
7229+
# deltas instead of double-counting them.
7230+
if self._session_db and self.session_id:
72297231
try:
72307232
self._session_db.update_token_counts(
72317233
self.session_id,

tests/gateway/test_session.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -825,43 +825,6 @@ def test_update_session_zero_resets(self, tmp_path):
825825
store.update_session("k1", last_prompt_tokens=0)
826826
assert entry.last_prompt_tokens == 0
827827

828-
def test_update_session_passes_model_to_db(self, tmp_path):
829-
"""Gateway session updates should forward the resolved model to SQLite."""
830-
config = GatewayConfig()
831-
with patch("gateway.session.SessionStore._ensure_loaded"):
832-
store = SessionStore(sessions_dir=tmp_path, config=config)
833-
store._loaded = True
834-
store._save = MagicMock()
835-
store._db = MagicMock()
836-
837-
from gateway.session import SessionEntry
838-
from datetime import datetime
839-
entry = SessionEntry(
840-
session_key="k1",
841-
session_id="s1",
842-
created_at=datetime.now(),
843-
updated_at=datetime.now(),
844-
)
845-
store._entries = {"k1": entry}
846-
847-
store.update_session("k1", model="openai/gpt-5.4")
848-
849-
store._db.set_token_counts.assert_called_once_with(
850-
"s1",
851-
input_tokens=0,
852-
output_tokens=0,
853-
cache_read_tokens=0,
854-
cache_write_tokens=0,
855-
estimated_cost_usd=None,
856-
cost_status=None,
857-
cost_source=None,
858-
billing_provider=None,
859-
billing_base_url=None,
860-
model="openai/gpt-5.4",
861-
absolute=True,
862-
)
863-
864-
865828
class TestRewriteTranscriptPreservesReasoning:
866829
"""rewrite_transcript must not drop reasoning fields from SQLite."""
867830

tests/gateway/test_status_command.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,5 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
126126
assert result == "ok"
127127
runner.session_store.update_session.assert_called_once_with(
128128
session_entry.session_key,
129-
input_tokens=120,
130-
output_tokens=45,
131-
cache_read_tokens=0,
132-
cache_write_tokens=0,
133129
last_prompt_tokens=80,
134-
model="openai/test-model",
135-
estimated_cost_usd=None,
136-
cost_status=None,
137-
cost_source=None,
138-
provider=None,
139-
base_url=None,
140130
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from types import SimpleNamespace
2+
from unittest.mock import MagicMock, patch
3+
4+
from run_agent import AIAgent
5+
6+
7+
def _mock_response(*, usage: dict, content: str = "done"):
8+
msg = SimpleNamespace(content=content, tool_calls=None)
9+
choice = SimpleNamespace(message=msg, finish_reason="stop")
10+
return SimpleNamespace(
11+
choices=[choice],
12+
model="test/model",
13+
usage=SimpleNamespace(**usage),
14+
)
15+
16+
17+
def _make_agent(session_db, *, platform: str):
18+
with (
19+
patch("run_agent.get_tool_definitions", return_value=[]),
20+
patch("run_agent.check_toolset_requirements", return_value={}),
21+
patch("run_agent.OpenAI"),
22+
):
23+
agent = AIAgent(
24+
api_key="test-key",
25+
quiet_mode=True,
26+
skip_context_files=True,
27+
skip_memory=True,
28+
session_db=session_db,
29+
session_id=f"{platform}-session",
30+
platform=platform,
31+
)
32+
agent.client = MagicMock()
33+
agent.client.chat.completions.create.return_value = _mock_response(
34+
usage={
35+
"prompt_tokens": 11,
36+
"completion_tokens": 7,
37+
"total_tokens": 18,
38+
}
39+
)
40+
return agent
41+
42+
43+
def test_run_conversation_persists_tokens_for_telegram_sessions():
44+
session_db = MagicMock()
45+
agent = _make_agent(session_db, platform="telegram")
46+
47+
result = agent.run_conversation("hello")
48+
49+
assert result["final_response"] == "done"
50+
session_db.update_token_counts.assert_called_once()
51+
assert session_db.update_token_counts.call_args.args[0] == "telegram-session"
52+
53+
54+
def test_run_conversation_persists_tokens_for_cron_sessions():
55+
session_db = MagicMock()
56+
agent = _make_agent(session_db, platform="cron")
57+
58+
result = agent.run_conversation("hello")
59+
60+
assert result["final_response"] == "done"
61+
session_db.update_token_counts.assert_called_once()
62+
assert session_db.update_token_counts.call_args.args[0] == "cron-session"

0 commit comments

Comments
 (0)