Skip to content

Commit 8130cd9

Browse files
committed
refactor(steering): extract SteeringHook, remove enable_steering config
Move interruption injection logic from _LoopHook into a standalone SteeringHook(AgentHook) that plugs into the hooks system. Steering is now always active — InterruptionChecker is created per-dispatch and SteeringHook is passed as an extra hook. When no interruptions arrive, drain_all() returns empty and the hook is a no-op. Removes enable_steering from AgentDefaults config, CLI wiring, and AgentLoop.__init__. Removes transform_context/convert_to_llm params that were unused after the per-tool cancellation drop. Made-with: Cursor
1 parent 53bf6bb commit 8130cd9

3 files changed

Lines changed: 48 additions & 79 deletions

File tree

nanobot/agent/loop.py

Lines changed: 21 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
from nanobot.agent.context import ContextBuilder
1616
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
1717
from nanobot.agent.memory import Consolidator, Dream
18-
from nanobot.agent.messages import AgentMessage
1918
from nanobot.agent.runner import AgentRunSpec, AgentRunner
20-
from nanobot.agent.steering import InterruptionChecker
19+
from nanobot.agent.steering import InterruptionChecker, SteeringHook
2120
from nanobot.agent.subagent import SubagentManager
2221
from nanobot.agent.tools.cron import CronTool
2322
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
@@ -41,10 +40,6 @@
4140
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebToolsConfig
4241
from nanobot.cron.service import CronService
4342

44-
TransformContextHook = Callable[[list[AgentMessage]], list[AgentMessage]]
45-
ConvertToLlmHook = Callable[[list[AgentMessage]], list[dict[str, Any]]]
46-
47-
4843
class _LoopHook(AgentHook):
4944
"""Core hook for the main loop."""
5045

@@ -58,7 +53,6 @@ def __init__(
5853
channel: str = "cli",
5954
chat_id: str = "direct",
6055
message_id: str | None = None,
61-
interruption_checker: InterruptionChecker | None = None,
6256
) -> None:
6357
self._loop = agent_loop
6458
self._on_progress = on_progress
@@ -67,7 +61,6 @@ def __init__(
6761
self._channel = channel
6862
self._chat_id = chat_id
6963
self._message_id = message_id
70-
self._interruption_checker = interruption_checker
7164
self._stream_buf = ""
7265

7366
def wants_streaming(self) -> bool:
@@ -88,20 +81,6 @@ async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> N
8881
await self._on_stream_end(resuming=resuming)
8982
self._stream_buf = ""
9083

91-
async def before_iteration(self, context: AgentHookContext) -> None:
92-
if not self._interruption_checker:
93-
return
94-
pending = self._interruption_checker.drain_all()
95-
if pending:
96-
combined = "\n\n---\n\n".join(m.content for m in pending)
97-
injection = (
98-
"[The user just sent a new message while you were working. "
99-
"Read it and decide: continue current work, switch to the "
100-
"new request, or address both.]\n\n" + combined
101-
)
102-
context.messages.append({"role": "user", "content": injection})
103-
logger.info("Steering: injected {} interruption(s) before LLM call", len(pending))
104-
10584
async def before_execute_tools(self, context: AgentHookContext) -> None:
10685
if self._on_progress:
10786
if not self._on_stream:
@@ -125,23 +104,6 @@ async def after_iteration(self, context: AgentHookContext) -> None:
125104
u.get("completion_tokens", 0),
126105
u.get("cached_tokens", 0),
127106
)
128-
# Steering: inject interruptions after tool results
129-
if self._interruption_checker and context.tool_calls:
130-
pending = self._interruption_checker.drain_all()
131-
if pending:
132-
combined = "\n\n---\n\n".join(m.content for m in pending)
133-
injection = (
134-
"[The user just sent a new message while you were working. "
135-
"Read it and decide: continue current work, switch to the "
136-
"new request, or address both.]\n\n" + combined
137-
)
138-
context.messages.append({"role": "user", "content": injection})
139-
logger.info("Steering: injected {} interruption(s)", len(pending))
140-
if self._on_progress:
141-
await self._on_progress(
142-
"New message merged into current conversation",
143-
tool_hint=True,
144-
)
145107

146108
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
147109
return self._loop._strip_think(content)
@@ -218,10 +180,6 @@ def __init__(
218180
channels_config: ChannelsConfig | None = None,
219181
timezone: str | None = None,
220182
hooks: list[AgentHook] | None = None,
221-
*,
222-
enable_steering: bool = False,
223-
transform_context: TransformContextHook | None = None,
224-
convert_to_llm: ConvertToLlmHook | None = None,
225183
):
226184
from nanobot.config.schema import ExecToolConfig, WebToolsConfig
227185

@@ -254,11 +212,6 @@ def __init__(
254212
self._last_usage: dict[str, int] = {}
255213
self._extra_hooks: list[AgentHook] = hooks or []
256214

257-
# Dual-layer architecture (opt-in)
258-
self.enable_steering = enable_steering
259-
self._transform_context = transform_context
260-
self._convert_to_llm = convert_to_llm or AgentMessage.to_llm_list
261-
262215
self.context = ContextBuilder(workspace, timezone=timezone)
263216
self.sessions = session_manager or SessionManager(workspace)
264217
self.tools = ToolRegistry()
@@ -389,17 +342,15 @@ async def _run_agent_loop(
389342
channel: str = "cli",
390343
chat_id: str = "direct",
391344
message_id: str | None = None,
392-
interruption_checker: InterruptionChecker | None = None,
345+
extra_hooks: list[AgentHook] | None = None,
393346
) -> tuple[str | None, list[str], list[dict]]:
394347
"""Run the agent iteration loop.
395348
396349
*on_stream*: called with each content delta during streaming.
397350
*on_stream_end(resuming)*: called when a streaming session finishes.
398351
``resuming=True`` means tool calls follow (spinner should restart);
399352
``resuming=False`` means this is the final response.
400-
401-
*interruption_checker*: when steering is enabled, used to merge user
402-
messages that arrive during tool execution.
353+
*extra_hooks*: per-call hooks (e.g. SteeringHook) merged with instance hooks.
403354
"""
404355
loop_hook = _LoopHook(
405356
self,
@@ -409,11 +360,11 @@ async def _run_agent_loop(
409360
channel=channel,
410361
chat_id=chat_id,
411362
message_id=message_id,
412-
interruption_checker=interruption_checker,
413363
)
364+
all_extra = self._extra_hooks + (extra_hooks or [])
414365
hook: AgentHook = (
415-
_LoopHookChain(loop_hook, self._extra_hooks)
416-
if self._extra_hooks
366+
_LoopHookChain(loop_hook, all_extra)
367+
if all_extra
417368
else loop_hook
418369
)
419370

@@ -450,16 +401,14 @@ async def run(self) -> None:
450401
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
451402
self._running = True
452403
await self._connect_mcp()
453-
logger.info("Agent loop started (steering={})", self.enable_steering)
404+
logger.info("Agent loop started")
454405

455406
while self._running:
456407
try:
457408
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
458409
except asyncio.TimeoutError:
459410
continue
460411
except asyncio.CancelledError:
461-
# Preserve real task cancellation so shutdown can complete cleanly.
462-
# Only ignore non-task CancelledError signals that may leak from integrations.
463412
if not self._running or asyncio.current_task().cancelling():
464413
raise
465414
continue
@@ -475,23 +424,20 @@ async def run(self) -> None:
475424
await self.bus.publish_outbound(result)
476425
continue
477426

478-
if self.enable_steering:
479-
active = [t for t in self._active_tasks.get(msg.session_key, []) if not t.done()]
480-
if active and msg.session_key in self._interrupt_checkers:
481-
await self._interrupt_checkers[msg.session_key].signal(msg)
482-
logger.info("Steering: signaled interruption for {}", msg.session_key)
483-
continue
427+
active = [t for t in self._active_tasks.get(msg.session_key, []) if not t.done()]
428+
if active and msg.session_key in self._interrupt_checkers:
429+
await self._interrupt_checkers[msg.session_key].signal(msg)
430+
logger.info("Steering: signaled interruption for {}", msg.session_key)
431+
continue
484432

485433
task = asyncio.create_task(self._dispatch(msg))
486434
self._active_tasks.setdefault(msg.session_key, []).append(task)
487435
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
488436

489437
async def _dispatch(self, msg: InboundMessage) -> None:
490438
"""Process a message: per-session serial, cross-session concurrent."""
491-
checker: InterruptionChecker | None = None
492-
if self.enable_steering:
493-
checker = InterruptionChecker()
494-
self._interrupt_checkers[msg.session_key] = checker
439+
checker = InterruptionChecker()
440+
self._interrupt_checkers[msg.session_key] = checker
495441
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
496442
gate = self._concurrency_gate or nullcontext()
497443
try:
@@ -531,7 +477,7 @@ async def on_stream_end(*, resuming: bool = False) -> None:
531477
response = await self._process_message(
532478
msg,
533479
on_stream=on_stream, on_stream_end=on_stream_end,
534-
interruption_checker=checker,
480+
extra_hooks=[SteeringHook(checker)],
535481
)
536482
if response is not None:
537483
await self.bus.publish_outbound(response)
@@ -550,8 +496,7 @@ async def on_stream_end(*, resuming: bool = False) -> None:
550496
content="Sorry, I encountered an error.",
551497
))
552498
finally:
553-
if self.enable_steering:
554-
self._interrupt_checkers.pop(msg.session_key, None)
499+
self._interrupt_checkers.pop(msg.session_key, None)
555500

556501
async def close_mcp(self) -> None:
557502
"""Drain pending background archives, then close MCP connections."""
@@ -583,7 +528,7 @@ async def _process_message(
583528
on_progress: Callable[[str], Awaitable[None]] | None = None,
584529
on_stream: Callable[[str], Awaitable[None]] | None = None,
585530
on_stream_end: Callable[..., Awaitable[None]] | None = None,
586-
interruption_checker: InterruptionChecker | None = None,
531+
extra_hooks: list[AgentHook] | None = None,
587532
) -> OutboundMessage | None:
588533
"""Process a single inbound message and return the response."""
589534
# System messages: parse origin from chat_id ("channel:chat_id")
@@ -607,7 +552,7 @@ async def _process_message(
607552
final_content, _, all_msgs = await self._run_agent_loop(
608553
messages, session=session, channel=channel, chat_id=chat_id,
609554
message_id=msg.metadata.get("message_id"),
610-
interruption_checker=interruption_checker,
555+
extra_hooks=extra_hooks,
611556
)
612557
self._save_turn(session, all_msgs, 1 + len(history))
613558
self._clear_runtime_checkpoint(session)
@@ -661,7 +606,7 @@ async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
661606
session=session,
662607
channel=msg.channel, chat_id=msg.chat_id,
663608
message_id=msg.metadata.get("message_id"),
664-
interruption_checker=interruption_checker,
609+
extra_hooks=extra_hooks,
665610
)
666611

667612
if final_content is None or not final_content.strip():
@@ -841,13 +786,13 @@ async def process_direct(
841786
on_progress: Callable[[str], Awaitable[None]] | None = None,
842787
on_stream: Callable[[str], Awaitable[None]] | None = None,
843788
on_stream_end: Callable[..., Awaitable[None]] | None = None,
844-
interruption_checker: InterruptionChecker | None = None,
789+
extra_hooks: list[AgentHook] | None = None,
845790
) -> OutboundMessage | None:
846791
"""Process a message directly and return the outbound payload."""
847792
await self._connect_mcp()
848793
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
849794
return await self._process_message(
850795
msg, session_key=session_key, on_progress=on_progress,
851796
on_stream=on_stream, on_stream_end=on_stream_end,
852-
interruption_checker=interruption_checker,
797+
extra_hooks=extra_hooks,
853798
)

nanobot/agent/steering.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from loguru import logger
99

10+
from nanobot.agent.hook import AgentHook, AgentHookContext
11+
1012
if TYPE_CHECKING:
1113
from nanobot.bus.events import InboundMessage
1214

@@ -15,7 +17,7 @@ class InterruptionChecker:
1517
"""
1618
Per-session interruption queue.
1719
18-
Written by ``AgentLoop.run()``; read by ``_run_agent_loop``.
20+
Written by ``AgentLoop.run()``; read by ``SteeringHook``.
1921
"""
2022

2123
def __init__(self) -> None:
@@ -41,3 +43,27 @@ def drain_all(self) -> list[InboundMessage]:
4143
@property
4244
def has_pending(self) -> bool:
4345
return not self._queue.empty()
46+
47+
48+
class SteeringHook(AgentHook):
49+
"""AgentHook that injects pending user interruptions before each LLM call.
50+
51+
Drop this into the ``hooks`` list of ``AgentLoop`` to enable steering.
52+
The hook is stateless beyond its reference to the per-session checker.
53+
"""
54+
55+
def __init__(self, checker: InterruptionChecker) -> None:
56+
self._checker = checker
57+
58+
async def before_iteration(self, context: AgentHookContext) -> None:
59+
pending = self._checker.drain_all()
60+
if not pending:
61+
return
62+
combined = "\n\n---\n\n".join(m.content for m in pending)
63+
injection = (
64+
"[The user just sent a new message while you were working. "
65+
"Read it and decide: continue current work, switch to the "
66+
"new request, or address both.]\n\n" + combined
67+
)
68+
context.messages.append({"role": "user", "content": injection})
69+
logger.info("Steering: injected {} interruption(s) before LLM call", len(pending))

nanobot/cli/commands.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,6 @@ def gateway(
681681
mcp_servers=config.tools.mcp_servers,
682682
channels_config=config.channels,
683683
timezone=config.agents.defaults.timezone,
684-
enable_steering=config.agents.defaults.enable_steering,
685684
)
686685

687686
# Set cron callback (needs agent)
@@ -913,7 +912,6 @@ def agent(
913912
mcp_servers=config.tools.mcp_servers,
914913
channels_config=config.channels,
915914
timezone=config.agents.defaults.timezone,
916-
enable_steering=config.agents.defaults.enable_steering,
917915
)
918916
restart_notice = consume_restart_notice_from_env()
919917
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):

0 commit comments

Comments
 (0)