Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions gateway/platforms/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,6 +3078,7 @@ async def send_slash_confirm(
async def send_update_prompt(
self, chat_id: str, prompt: str, default: str = "",
session_key: str = "",
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send an interactive button-based update prompt (Yes / No).

Expand All @@ -3087,9 +3088,10 @@ async def send_update_prompt(
if not self._client or not DISCORD_AVAILABLE:
return SendResult(success=False, error="Not connected")
try:
channel = self._client.get_channel(int(chat_id))
target_id = metadata.get("thread_id") if metadata and metadata.get("thread_id") else chat_id
channel = self._client.get_channel(int(target_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
channel = await self._client.fetch_channel(int(target_id))

default_hint = f" (default: {default})" if default else ""
embed = discord.Embed(
Expand Down
4 changes: 4 additions & 0 deletions gateway/platforms/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,7 @@ async def delete_message(self, chat_id: str, message_id: str) -> bool:
async def send_update_prompt(
self, chat_id: str, prompt: str, default: str = "",
session_key: str = "",
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send an inline-keyboard update prompt (Yes / No buttons).

Expand All @@ -1377,11 +1378,14 @@ async def send_update_prompt(
InlineKeyboardButton("✗ No", callback_data="update_prompt:n"),
]
])
thread_id = self._metadata_thread_id(metadata)
message_thread_id = self._message_thread_id_for_send(thread_id)
msg = await self._bot.send_message(
chat_id=int(chat_id),
text=text,
parse_mode=ParseMode.MARKDOWN,
reply_markup=keyboard,
message_thread_id=message_thread_id,
**self._link_preview_kwargs(),
)
return SendResult(success=True, message_id=str(msg.message_id))
Expand Down
29 changes: 23 additions & 6 deletions gateway/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9490,6 +9490,8 @@ async def _handle_update_command(self, event: MessageEvent) -> str:
"session_key": session_key,
"timestamp": datetime.now().isoformat(),
}
if event.source.thread_id:
pending["thread_id"] = event.source.thread_id
_tmp_pending = pending_path.with_suffix(".tmp")
_tmp_pending.write_text(json.dumps(pending))
_tmp_pending.replace(pending_path)
Expand Down Expand Up @@ -9575,13 +9577,16 @@ async def _watch_update_progress(
adapter = None
chat_id = None
session_key = None
metadata = None
for path in (claimed_path, pending_path):
if path.exists():
try:
pending = json.loads(path.read_text())
platform_str = pending.get("platform")
chat_id = pending.get("chat_id")
session_key = pending.get("session_key")
thread_id = pending.get("thread_id")
metadata = {"thread_id": thread_id} if thread_id else None
if platform_str and chat_id:
platform = Platform(platform_str)
adapter = self.adapters.get(platform)
Expand Down Expand Up @@ -9629,7 +9634,7 @@ async def _flush_buffer() -> None:
chunks = [clean[i:i + max_chunk] for i in range(0, len(clean), max_chunk)]
for chunk in chunks:
try:
await adapter.send(chat_id, f"```\n{chunk}\n```")
await adapter.send(chat_id, f"```\n{chunk}\n```", metadata=metadata)
except Exception as e:
logger.debug("Update stream send failed: %s", e)

Expand All @@ -9652,9 +9657,13 @@ async def _flush_buffer() -> None:
exit_code_raw = exit_code_path.read_text().strip() or "1"
exit_code = int(exit_code_raw)
if exit_code == 0:
await adapter.send(chat_id, "✅ Hermes update finished.")
await adapter.send(chat_id, "✅ Hermes update finished.", metadata=metadata)
else:
await adapter.send(chat_id, "❌ Hermes update failed (exit code {}).".format(exit_code))
await adapter.send(
chat_id,
"❌ Hermes update failed (exit code {}).".format(exit_code),
metadata=metadata,
)
logger.info("Update finished (exit=%s), notified %s", exit_code, session_key)
except Exception as e:
logger.warning("Update final notification failed: %s", e)
Expand Down Expand Up @@ -9704,6 +9713,7 @@ async def _flush_buffer() -> None:
prompt=prompt_text,
default=default,
session_key=session_key,
metadata=metadata,
)
sent_buttons = True
except Exception as btn_err:
Expand All @@ -9715,7 +9725,8 @@ async def _flush_buffer() -> None:
f"⚕ **Update needs your input:**\n\n"
f"{prompt_text}{default_hint}\n\n"
f"Reply `/approve` (yes) or `/deny` (no), "
f"or type your answer directly."
f"or type your answer directly.",
metadata=metadata,
)
self._update_prompt_pending[session_key] = True
# Remove the prompt file so it isn't re-read on the
Expand All @@ -9735,7 +9746,11 @@ async def _flush_buffer() -> None:
exit_code_path.write_text("124")
await _flush_buffer()
try:
await adapter.send(chat_id, "❌ Hermes update timed out after 30 minutes.")
await adapter.send(
chat_id,
"❌ Hermes update timed out after 30 minutes.",
metadata=metadata,
)
except Exception:
pass
for p in (pending_path, claimed_path, output_path,
Expand Down Expand Up @@ -9777,6 +9792,7 @@ async def _send_update_notification(self) -> bool:
pending = json.loads(claimed_path.read_text())
platform_str = pending.get("platform")
chat_id = pending.get("chat_id")
thread_id = pending.get("thread_id")

if not exit_code_path.exists():
logger.info("Update notification deferred: update still running")
Expand All @@ -9798,6 +9814,7 @@ async def _send_update_notification(self) -> bool:
adapter = self.adapters.get(platform)

if adapter and chat_id:
metadata = {"thread_id": thread_id} if thread_id else None
# Strip ANSI escape codes for clean display
output = re.sub(r'\x1b\[[0-9;]*m', '', output).strip()
if output:
Expand All @@ -9812,7 +9829,7 @@ async def _send_update_notification(self) -> bool:
msg = "✅ Hermes update finished successfully."
else:
msg = "❌ Hermes update failed. Check the gateway logs or run `hermes update` manually for details."
await adapter.send(chat_id, msg)
await adapter.send(chat_id, msg, metadata=metadata)
logger.info(
"Sent post-update notification to %s:%s (exit=%s)",
platform_str,
Expand Down
56 changes: 55 additions & 1 deletion tests/gateway/test_update_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@


def _make_event(text="/update", platform=Platform.TELEGRAM,
user_id="12345", chat_id="67890"):
user_id="12345", chat_id="67890", thread_id=None):
"""Build a MessageEvent for testing."""
source = SessionSource(
platform=platform,
user_id=user_id,
chat_id=chat_id,
user_name="testuser",
thread_id=thread_id,
)
return MessageEvent(text=text, source=source)

Expand Down Expand Up @@ -214,6 +215,34 @@ async def test_writes_pending_marker(self, tmp_path):
assert "timestamp" in data
assert not (hermes_home / ".update_exit_code").exists()

@pytest.mark.asyncio
async def test_writes_pending_marker_with_thread_id(self, tmp_path):
"""Persists thread_id so update notifications can route back to the thread."""
runner = _make_runner()
event = _make_event(
platform=Platform.TELEGRAM,
chat_id="99999",
thread_id="777",
)

fake_root = tmp_path / "project"
fake_root.mkdir()
(fake_root / ".git").mkdir()
(fake_root / "gateway").mkdir()
(fake_root / "gateway" / "run.py").touch()
fake_file = str(fake_root / "gateway" / "run.py")
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()

with patch("gateway.run._hermes_home", hermes_home), \
patch("gateway.run.__file__", fake_file), \
patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/setsid"), \
patch("subprocess.Popen"):
await runner._handle_update_command(event)

data = json.loads((hermes_home / ".update_pending.json").read_text())
assert data["thread_id"] == "777"

@pytest.mark.asyncio
async def test_spawns_setsid(self, tmp_path):
"""Uses setsid when available."""
Expand Down Expand Up @@ -432,6 +461,31 @@ async def test_sends_notification_with_output(self, tmp_path):
assert call_args[0][0] == "67890" # chat_id
assert "Update complete" in call_args[0][1] or "update finished" in call_args[0][1].lower()

@pytest.mark.asyncio
async def test_sends_notification_with_thread_metadata(self, tmp_path):
"""Final update notification preserves thread metadata when present."""
runner = _make_runner()
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()

pending = {
"platform": "telegram",
"chat_id": "67890",
"thread_id": "777",
"user_id": "12345",
}
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
(hermes_home / ".update_output.txt").write_text("done")
(hermes_home / ".update_exit_code").write_text("0")

mock_adapter = AsyncMock()
runner.adapters = {Platform.TELEGRAM: mock_adapter}

with patch("gateway.run._hermes_home", hermes_home):
await runner._send_update_notification()

assert mock_adapter.send.call_args.kwargs["metadata"] == {"thread_id": "777"}

@pytest.mark.asyncio
async def test_strips_ansi_codes(self, tmp_path):
"""ANSI escape codes are removed from output."""
Expand Down
52 changes: 52 additions & 0 deletions tests/gateway/test_update_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,58 @@ async def simulate_prompt_cycle():
# Check session was marked as having pending prompt
# (may be cleared by the time we check since update finished)

@pytest.mark.asyncio
async def test_prompt_forwarding_preserves_thread_metadata(self, tmp_path):
"""Forwarded update prompts keep the originating thread/topic metadata."""
runner = _make_runner()
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()

pending = {
"platform": "telegram",
"chat_id": "111",
"thread_id": "777",
"user_id": "222",
"session_key": "agent:main:telegram:group:111:777",
}
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
(hermes_home / ".update_output.txt").write_text("")
(hermes_home / ".update_prompt.json").write_text(json.dumps({
"prompt": "Restore local changes? [Y/n]",
"default": "y",
"id": "threaded-prompt",
}))

class _PromptCapableAdapter:
def __init__(self):
self.send = AsyncMock()
self.prompt_calls = AsyncMock()

async def send_update_prompt(self, **kwargs):
return await self.prompt_calls(**kwargs)

mock_adapter = _PromptCapableAdapter()
runner.adapters = {Platform.TELEGRAM: mock_adapter}

async def finish_after_prompt():
await asyncio.sleep(0.3)
(hermes_home / ".update_response").write_text("y")
await asyncio.sleep(0.2)
(hermes_home / ".update_exit_code").write_text("0")

with patch("gateway.run._hermes_home", hermes_home):
task = asyncio.create_task(finish_after_prompt())
await runner._watch_update_progress(
poll_interval=0.1,
stream_interval=0.2,
timeout=5.0,
)
await task

assert mock_adapter.prompt_calls.call_args.kwargs["metadata"] == {
"thread_id": "777"
}

@pytest.mark.asyncio
async def test_cleans_up_on_completion(self, tmp_path):
"""All marker files are cleaned up when update finishes."""
Expand Down
Loading