Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions nanobot/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import select
import signal
import sys
import time
from collections.abc import Callable
from contextlib import nullcontext, suppress
from pathlib import Path
Expand Down Expand Up @@ -747,13 +748,57 @@ async def _silent(*_args, **_kwargs):
if isinstance(message_tool, MessageTool):
message_record_token = message_tool.set_record_channel_delivery(True)

channel_name = job.payload.channel or "cli"
chat_id = job.payload.to or "direct"
try:
target_channel = channels.channels.get(channel_name)
except NameError:
target_channel = None
wants_stream = target_channel is not None and target_channel.supports_streaming

stream_base_id = None
stream_segment = 0

def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"

async def _on_stream(delta: str) -> None:
meta = dict(job.payload.channel_meta)
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content=delta,
metadata=meta,
))

async def _on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(job.payload.channel_meta)
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata=meta,
))
stream_segment += 1

if wants_stream:
stream_base_id = f"cron:{job.id}:{time.time_ns()}"

try:
resp = await agent.process_direct(
reminder_note,
session_key=f"cron:{job.id}",
channel=job.payload.channel or "cli",
chat_id=job.payload.to or "direct",
channel=channel_name,
chat_id=chat_id,
on_progress=_silent,
on_stream=_on_stream if wants_stream else None,
on_stream_end=_on_stream_end if wants_stream else None,
)
finally:
if isinstance(cron_tool, CronTool) and cron_token is not None:
Expand All @@ -764,23 +809,40 @@ async def _silent(*_args, **_kwargs):
response = resp.content if resp else ""

if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
if wants_stream:
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata={**job.payload.channel_meta, "_turn_end": True},
))
return response

if job.payload.deliver and job.payload.to and response:
should_notify = await evaluate_response(
response, reminder_note, agent.provider, agent.model,
)
if should_notify:
meta = dict(job.payload.channel_meta)
if wants_stream:
meta["_streamed"] = True
await _deliver_to_channel(
OutboundMessage(
channel=job.payload.channel or "cli",
chat_id=job.payload.to,
channel=channel_name,
chat_id=chat_id,
content=response,
metadata=dict(job.payload.channel_meta),
metadata=meta,
),
record=True,
session_key=job.payload.session_key,
)
if wants_stream:
await bus.publish_outbound(OutboundMessage(
channel=channel_name,
chat_id=chat_id,
content="",
metadata={**job.payload.channel_meta, "_turn_end": True},
))
return response

cron.on_job = on_cron_job
Expand Down
133 changes: 133 additions & 0 deletions tests/cli/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,139 @@ async def _always_reject(*_args, **_kwargs) -> bool:
bus.publish_outbound.assert_not_awaited()


def test_gateway_cron_job_streams_when_channel_supports_it(
monkeypatch, tmp_path: Path
) -> None:
"""Cron jobs on streaming channels must emit deltas with stream_id and turn_end."""
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")

config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
bus = MagicMock()
bus.publish_outbound = AsyncMock()
seen: dict[str, object] = {}

monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.providers.factory.make_provider", lambda _config: _fake_provider())
monkeypatch.setattr(
"nanobot.providers.factory.build_provider_snapshot",
lambda _config: _test_provider_snapshot(object(), _config),
)
monkeypatch.setattr(
"nanobot.providers.factory.load_provider_snapshot",
lambda _config_path=None: _test_provider_snapshot(object(), config),
)
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())

class _FakeStreamingChannel:
supports_streaming = True

class _FakeChannelManager:
def __init__(self, *_args, **_kwargs) -> None:
self.channels = {"websocket": _FakeStreamingChannel()}
self.enabled_channels = ["websocket"]

async def start_all(self):
pass

async def stop_all(self):
pass

class _FakeCron:
def __init__(self, _store_path: Path) -> None:
self.on_job = None
seen["cron"] = self

def status(self):
return {"enabled": True, "jobs": 0, "next_wake_at_ms": None}

def register_system_job(self, job):
pass

def stop(self):
pass

class _FakeAgentLoop:
@classmethod
def from_config(cls, config, bus=None, **extra):
return cls(**extra)
def __init__(self, *args, **kwargs) -> None:
self.model = "test-model"
self.provider = object()
self.tools = {}
self.dream = MagicMock()
self.sessions = MagicMock()

async def process_direct(self, *_args, on_stream=None, on_stream_end=None, **_kwargs):
seen["on_stream"] = on_stream
seen["on_stream_end"] = on_stream_end
if on_stream:
await on_stream("Hello")
await on_stream(" world")
if on_stream_end:
await on_stream_end(resuming=False)
return OutboundMessage(
channel="websocket",
chat_id="user-1",
content="Hello world",
)

async def close_mcp(self) -> None:
return None

async def run(self) -> None:
return None

def stop(self) -> None:
return None

monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.cli.commands.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)

result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert result.exit_code == 0

cron = seen["cron"]
job = CronJob(
id="cron-stream-test",
name="test-stream",
payload=CronPayload(
message="Say hello.",
deliver=True,
channel="websocket",
to="user-1",
),
)
response = asyncio.run(cron.on_job(job))

assert response == "Hello world"
assert seen["on_stream"] is not None
assert seen["on_stream_end"] is not None

calls = bus.publish_outbound.await_args_list
# First two calls are streaming deltas
assert calls[0].args[0].metadata.get("_stream_delta") is True
assert calls[0].args[0].metadata.get("_stream_id") is not None
assert calls[0].args[0].content == "Hello"
assert calls[1].args[0].metadata.get("_stream_delta") is True
assert calls[1].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
assert calls[1].args[0].content == " world"
# Third call is stream_end
assert calls[2].args[0].metadata.get("_stream_end") is True
assert calls[2].args[0].metadata.get("_stream_id") == calls[0].args[0].metadata["_stream_id"]
# Fourth call is the final message with _streamed marker
assert calls[3].args[0].metadata.get("_streamed") is True
assert calls[3].args[0].content == "Hello world"
# Fifth call is turn_end
assert calls[4].args[0].metadata.get("_turn_end") is True


def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
Expand Down
Loading