diff --git a/.gitignore b/.gitignore index 054e5ce7032..733999f9e7e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,13 @@ .env .web .orion +.context/ + +# Pipeline artifacts +docs/brainstorms/ +docs/ideation/ +docs/plans/ +docs/solutions/ # webui (monorepo frontend) webui/node_modules/ diff --git a/docs/README.md b/docs/README.md index d8ff30247b9..c794cddb97b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -29,6 +29,7 @@ Use these when you want deeper customization, integration, or extension details. | Memory | [`memory.md`](./memory.md) | How nanobot stores, consolidates, and restores memory | | Python SDK | [`python-sdk.md`](./python-sdk.md) | Use nanobot programmatically from Python | | Channel plugin guide | [`channel-plugin-guide.md`](./channel-plugin-guide.md) | Build and test custom chat channel plugins | +| Hook plugin guide | [`hook-plugin-guide.md`](./hook-plugin-guide.md) | Build and distribute hook plugins for agent lifecycle events | | WebSocket channel | [`websocket.md`](./websocket.md) | Real-time WebSocket access and protocol details | | Custom tools | [`my-tool.md`](./my-tool.md) | Inspect and tune runtime state with the `my` tool | diff --git a/docs/configuration.md b/docs/configuration.md index f295b50c554..a9b281bbea9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -785,6 +785,27 @@ MCP tools are automatically discovered and registered on startup. The LLM can us +## Hook Plugins + +nanobot discovers hook plugins installed via `pip` through the `nanobot.hooks` entry point group. **No plugins are loaded by default** — you must explicitly list the plugins to enable in `hooks.enabled_plugins`. This is a security control to prevent unintended code execution. + +When `enabled_plugins` is omitted or set to `null`, all discovered plugins are **skipped**. 
+ +```json +{ + "hooks": { + "enabled_plugins": ["ratelimit", "audit-log"] + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `hooks.enabled_plugins` | `null` (deny all) | List of plugin entry-point names to allow. When `null` or unset, no external hook plugins are loaded. Only listed plugins have their entry points loaded; all others are skipped before module-level code executes. | + +See the [hook plugin guide](./hook-plugin-guide.md) for building and packaging hook plugins. + + ## Security > [!TIP] diff --git a/docs/hook-plugin-guide.md b/docs/hook-plugin-guide.md new file mode 100644 index 00000000000..2eee213616e --- /dev/null +++ b/docs/hook-plugin-guide.md @@ -0,0 +1,239 @@ +# Hook Plugin Guide + +Build a custom nanobot hook plugin in three steps: implement, package, install. + +Hooks let you observe, transform, or guard agent lifecycle events — without modifying nanobot internals. + +## How It Works + +nanobot discovers hook plugins via Python [entry points](https://packaging.python.org/en/latest/specifications/entry-points/), the same mechanism used by channel plugins. When `nanobot gateway` starts, the HookCenter: + +1. Scans the external packages registered under the `nanobot.hooks` entry point group +2. Loads only the plugins named in the `hooks.enabled_plugins` allowlist; when the allowlist is unset or `null`, none are loaded + +## Quick Start + +We'll build a minimal rate-limiting hook plugin that blocks excessive tool calls. + +### Project Structure + +```text +nanobot-hook-ratelimit/ +├── nanobot_hook_ratelimit/ +│ ├── __init__.py # re-export RateLimitHandler +│ └── handler.py # handler implementation +└── pyproject.toml +``` + +### 1. 
Implement Your Handler + +```python +# nanobot_hook_ratelimit/__init__.py +from nanobot_hook_ratelimit.handler import RateLimitHandler + +__all__ = ["RateLimitHandler"] +``` + +```python +# nanobot_hook_ratelimit/handler.py +from nanobot.hooks import BeforeExecuteTools, Deny, Modified + + + +class RateLimitHandler: + """Block tool execution when a per-session limit is exceeded.""" + + # Register for tool execution events as a guard handler. + hook_events = [(BeforeExecuteTools, "guard")] + + def __init__(self, max_tools_per_turn: int = 10) -> None: + self._max_tools_per_turn = max_tools_per_turn + self._counts: dict[str, int] = {} + + async def __call__(self, event: BeforeExecuteTools): + session_id = getattr(event, "session_key", "default") + count = self._counts.get(session_id, 0) + + if count >= self._max_tools_per_turn: + return Deny( + f"Rate limit: max {self._max_tools_per_turn} tools per turn " + f"(current: {count})" + ) + + self._counts[session_id] = count + len(event.tool_calls) + return None + + +class BlocklistHandler: + """Abort the agent loop if a blocked tool is called.""" + + hook_events = [(BeforeExecuteTools, "guard")] + + def __init__(self, blocked_tools: list[str] | None = None) -> None: + self._blocked = set(blocked_tools or []) + + async def __call__(self, event: BeforeExecuteTools): + for tc in event.tool_calls: + if tc.name in self._blocked: + return Deny( + f"Blocked tool '{tc.name}' — agent loop aborted", + abort=True, + ) + return None +``` + +### 2. 
Register the Entry Point + +```toml +# pyproject.toml +[project] +name = "nanobot-hook-ratelimit" +version = "0.1.0" +dependencies = ["nanobot-ai"] + +[project.entry-points."nanobot.hooks"] +ratelimit = "nanobot_hook_ratelimit:RateLimitHandler" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["nanobot_hook_ratelimit"] +``` + +The key (`ratelimit`) becomes the plugin name shown in logs and used in the `enabled_plugins` allowlist. The value points to your handler class. + +### 3. Install & Configure + +```bash +pip install -e . +``` + +Edit `~/.nanobot/config.json` to enable the plugin: + +```json +{ + "hooks": { + "enabled_plugins": ["ratelimit"] + } +} +``` + +When `enabled_plugins` is set, **only** listed plugins are loaded. Without this key (or when set to `null`), **no** discovered plugins are loaded — this is a default-deny security policy. + +### 4. Verify + +```bash +nanobot gateway +``` + +Check the logs — you should see: +``` +Registered hook plugin 'ratelimit' with 1 events +``` + +## Handler Contract + +A hook handler is any callable matching the `HookHandler` protocol: + +```python +async def handler(event: EventType) -> HookResult | Modified | Deny | None: ... 
+``` + +| Return value | Semantic | Effect | +|-------------|----------|--------| +| `None` | Observe | No action needed; handler ran as a side-effect | +| `Modified(data)` | Transform | Apply the returned data to the event (dict keys mapped to event fields) | +| `Deny(reason)` | Guard (soft) | Block the operation — runner injects the reason as a tool result and continues the loop | +| `Deny(reason, abort=True)` | Guard (hard) | Block the operation — runner terminates the agent loop immediately, `reason` becomes the final content | + +### Declaring subscriptions + +Your handler class or module must expose a `hook_events` attribute: + +```python +hook_events: list[tuple[type, str]] = [ + (BeforeIteration, "observe"), + (BeforeExecuteTools, "guard"), + (AfterIteration, "observe"), +] +``` + +Each tuple is `(event_type, mode)`. Mode must be one of `"guard"`, `"transform"`, or `"observe"`. + +## Event Types + +v1 exposes six event types covering the agent iteration lifecycle: + +| Event | Fields | Mode | +|-------|--------|------| +| `BeforeIteration` | `iteration`, `messages` | guard, observe | +| `OnStream` | `delta`, `iteration` | observe | +| `OnStreamEnd` | `resuming`, `iteration` | observe | +| `BeforeExecuteTools` | `iteration`, `tool_calls`, `response` | guard, observe | +| `AfterIteration` | `iteration`, `final_content`, `stop_reason`, `usage`, `tool_calls`, `tool_events`, `tool_results`, `error` | observe | +| `FinalizeContent` | *(registration marker only)* | transform pipeline | + +All event types are importable from `nanobot.hooks`: + +```python +from nanobot.hooks import ( + BeforeIteration, + AfterIteration, + BeforeExecuteTools, + OnStream, + OnStreamEnd, + FinalizeContent, + Deny, + Modified, +) +``` + +## Dispatch Order + +Within a single event emission, handlers run in this order: + +1. **Guards** (internal, then external) — first `Deny` value short-circuits; remaining handlers are skipped. +2. 
**Transforms** (internal, then external) — chained pipeline; each handler receives data modified by the previous one. +3. **Observes** (internal, then external) — sequential execution with per-handler error isolation. + +Internal handlers (built-in framework logic such as streaming and progress) always run before external plugins. + +## Security + +Hook plugin entry-point loading carries inherent security implications. When `nanobot gateway` starts, the `HookCenter` loads **only** the hooks listed in `hooks.enabled_plugins`. No plugins are loaded by default — you must explicitly opt in. Any hook plugin has **full access to the agent process** — all conversational data, in-memory state, filesystem access, and network access. + +**Important controls:** + +- Set `hooks.enabled_plugins` to an explicit allowlist to control which plugins load. Plugins not in this list are skipped before their module-level code executes. +- Audit your plugin dependencies. Any enabled hook package executes arbitrary Python code as soon as its entry point is loaded (`ep.load()`), so review a plugin and its dependencies before adding it to the allowlist. +- For high-security deployments, consider running nanobot in a sandboxed environment (`tools.restrictToWorkspace`, `tools.exec.sandbox: bwrap`). + +## Naming Convention + +| What | Format | Example | +|------|--------|---------| +| PyPI package | `nanobot-hook-{name}` | `nanobot-hook-ratelimit` | +| Entry point key | `{name}` | `ratelimit` | +| Config allowlist | `hooks.enabled_plugins[{name}]` | `ratelimit` | +| Python package | `nanobot_hook_{name}` | `nanobot_hook_ratelimit` | + +## Built-in Hook API (AgentHook, backward-compatible) + +Legacy `AgentHook` subclasses remain fully supported through a compatibility adapter. 
Existing hook code (such as the Python SDK usage below) continues to work unchanged: + +```python +from nanobot.agent import AgentHook, AgentHookContext + + +class AuditHook(AgentHook): + async def before_execute_tools(self, context: AgentHookContext) -> None: + for tc in context.tool_calls: + print(f"[audit] {tc.name}") + +# Works as before — adapted internally to HookCenter +result = await bot.run("hello", hooks=[AuditHook()]) +``` + +See the [Python SDK guide](./python-sdk.md) for the full SDK hooks API reference. diff --git a/docs/python-sdk.md b/docs/python-sdk.md index 5ee66a349e2..00b16602904 100644 --- a/docs/python-sdk.md +++ b/docs/python-sdk.md @@ -93,6 +93,48 @@ Run the agent once and return a `RunResult`. ## Hooks +There are two ways to write hooks: the new **event-based HookCenter API** and the **legacy AgentHook API**. Both work — the legacy AgentHook is automatically adapted to the HookCenter at runtime. + +### Event-based hooks (recommended for new plugins) + +The event-based API uses typed event dataclasses. 
Handlers subscribe to specific event types and return `None` (observe), `Modified(data)` (transform), or `Deny(reason)` (guard): + +```python +from nanobot.hooks import BeforeExecuteTools, Deny + + + +class RateLimitHandler: + hook_events = [(BeforeExecuteTools, "guard")] + + def __init__(self, max_calls: int = 10): + self._count = 0 + self._max_calls = max_calls + + async def __call__(self, event: BeforeExecuteTools): + self._count += len(event.tool_calls) + if self._count > self._max_calls: + return Deny(f"Rate limit exceeded ({self._max_calls} max)") + return None +``` + +Event types importable from `nanobot.hooks`: + +| Event | Purpose | +|-------|---------| +| `BeforeIteration(iteration, messages)` | Before each LLM iteration | +| `OnStream(delta, iteration)` | On each streaming token | +| `OnStreamEnd(resuming, iteration)` | When streaming finishes | +| `BeforeExecuteTools(iteration, tool_calls, response)` | Before tool execution | +| `AfterIteration(iteration, final_content, stop_reason, usage, ...)` | After each iteration | +| `FinalizeContent` *(registration marker)* | Content transform pipeline | + +Return types: `Deny(reason)`, `Modified(data)`, `None`. + +For packaging and distribution (entry_points), see the [hook plugin guide](./hook-plugin-guide.md). + +### Legacy AgentHook API (backward-compatible) + Hooks let you observe or customize the agent loop. Subclass `AgentHook` and override the methods you need. 
### Hook lifecycle diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index cbddfc28640..c7318c61639 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -41,6 +41,8 @@ from nanobot.bus.queue import MessageBus from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.config.schema import AgentDefaults +from nanobot.hooks.adapters import adapt_agent_hook_list +from nanobot.hooks.center import HookCenter from nanobot.providers.base import LLMProvider from nanobot.providers.factory import ProviderSnapshot from nanobot.session.manager import Session, SessionManager @@ -209,6 +211,7 @@ def __init__( tools_config: ToolsConfig | None = None, provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None, provider_signature: tuple[object, ...] | None = None, + hooks_config: Any | None = None, ): from nanobot.config.schema import ExecToolConfig, ToolsConfig, WebToolsConfig @@ -243,6 +246,8 @@ def __init__( self._start_time = time.time() self._last_usage: dict[str, int] = {} self._extra_hooks: list[AgentHook] = hooks or [] + self._hook_center = HookCenter() + self._hook_center.discover(hooks_config) self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.sessions = session_manager or SessionManager(workspace) @@ -259,6 +264,7 @@ def __init__( restrict_to_workspace=restrict_to_workspace, disabled_skills=disabled_skills, max_iterations=self.max_iterations, + hook_center=self._hook_center, ) self._unified_session = unified_session self._max_messages = max_messages if max_messages > 0 else 120 @@ -618,13 +624,17 @@ def _to_user_message(pending_msg: InboundMessage) -> dict[str, Any]: return items + center = self._hook_center + hook_session = center.create_session() + adapt_agent_hook_list([hook], hook_session, center) result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, model=self.model, max_iterations=self.max_iterations, 
max_tool_result_chars=self.max_tool_result_chars, - hook=hook, + center=center, + session=hook_session, error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, workspace=self.workspace, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index c7cf126c31b..11c0280f607 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -11,9 +11,18 @@ from loguru import logger -from nanobot.agent.hook import AgentHook, AgentHookContext +from nanobot.agent.hook import AgentHookContext from nanobot.agent.tools.ask import AskUserInterrupt from nanobot.agent.tools.registry import ToolRegistry +from nanobot.hooks import ( + AfterIteration, + BeforeExecuteTools, + BeforeIteration, + Deny, + OnStream, + OnStreamEnd, +) +from nanobot.hooks.center import HookCenter, HookSession from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.utils.helpers import ( build_assistant_message, @@ -63,7 +72,8 @@ class AgentRunSpec: temperature: float | None = None max_tokens: int | None = None reasoning_effort: str | None = None - hook: AgentHook | None = None + center: HookCenter | None = None + session: HookSession | None = None error_message: str | None = _DEFAULT_ERROR_MESSAGE max_iterations_message: str | None = None concurrent_tools: bool = False @@ -94,6 +104,19 @@ class AgentRunResult: had_injections: bool = False +def _make_after_iteration(context: AgentHookContext, iteration: int) -> AfterIteration: + return AfterIteration( + iteration=iteration, + final_content=context.final_content, + stop_reason=context.stop_reason, + usage=dict(context.usage), + tool_calls=list(context.tool_calls), + tool_events=list(context.tool_events), + tool_results=list(context.tool_results), + error=context.error, + ) + + class AgentRunner: """Run a tool-capable LLM loop without product-layer concerns.""" @@ -229,7 +252,8 @@ async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: return 
injected_messages async def run(self, spec: AgentRunSpec) -> AgentRunResult: - hook = spec.hook or AgentHook() + center = spec.center or HookCenter() + session = spec.session or center.create_session() messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] @@ -242,6 +266,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: length_recovery_count = 0 had_injections = False injection_cycles = 0 + context: AgentHookContext | None = None for iteration in range(spec.max_iterations): try: @@ -270,8 +295,25 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: except Exception: messages_for_model = messages context = AgentHookContext(iteration=iteration, messages=messages) - await hook.before_iteration(context) - response = await self._request_model(spec, messages_for_model, hook, context) + session.context = context + bi_result = await center.emit(BeforeIteration(iteration=iteration, messages=messages), session) + if isinstance(bi_result, Deny): + if bi_result.abort: + final_content = bi_result.reason + stop_reason = "aborted" + context.final_content = final_content + context.stop_reason = stop_reason + await center.emit(_make_after_iteration(context, iteration), session) + break + messages.append({ + "role": "user", + "content": f"[System] Operation denied by hook guard: {bi_result.reason}", + }) + empty_content_retries = 0 + length_recovery_count = 0 + await center.emit(_make_after_iteration(context, iteration), session) + continue + response = await self._request_model(spec, messages_for_model, center, session, context) raw_usage = self._usage_dict(response.usage) context.response = response context.usage = dict(raw_usage) @@ -284,8 +326,8 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: if ask_index is not None: tool_calls = tool_calls[: ask_index + 1] context.tool_calls = list(tool_calls) - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=True) + if 
center.wants_streaming(session): + await center.emit(OnStreamEnd(resuming=True, iteration=iteration), session) assistant_message = build_assistant_message( response.content or "", @@ -307,7 +349,30 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: }, ) - await hook.before_execute_tools(context) + result = await center.emit(BeforeExecuteTools(iteration=iteration, tool_calls=context.tool_calls, response=context.response), session) + if isinstance(result, Deny): + if result.abort: + final_content = result.reason + stop_reason = "aborted" + context.final_content = final_content + context.stop_reason = stop_reason + await center.emit(_make_after_iteration(context, iteration), session) + break + for tc in tool_calls: + messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "name": tc.name, + "content": f"Tool execution was denied: {result.reason}", + }) + tool_events.extend( + {"name": tc.name, "status": "ok", "detail": f"[denied] {result.reason}"} + for tc in tool_calls + ) + empty_content_retries = 0 + length_recovery_count = 0 + await center.emit(_make_after_iteration(context, iteration), session) + continue results, new_events, fatal_error = await self._execute_tools( spec, @@ -340,9 +405,9 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: stop_reason = "ask_user" context.final_content = final_content context.stop_reason = stop_reason - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) - await hook.after_iteration(context) + if center.wants_streaming(session): + await center.emit(OnStreamEnd(resuming=False, iteration=iteration), session) + await center.emit(_make_after_iteration(context, iteration), session) break error = f"Error: {type(fatal_error).__name__}: {fatal_error}" final_content = error @@ -351,7 +416,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: context.final_content = final_content context.error = error context.stop_reason = stop_reason - await hook.after_iteration(context) + 
await center.emit(_make_after_iteration(context, iteration), session) should_continue, injection_cycles = await self._try_drain_injections( spec, messages, None, injection_cycles, phase="after tool error", @@ -380,7 +445,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: ) if _drained: had_injections = True - await hook.after_iteration(context) + await center.emit(_make_after_iteration(context, iteration), session) continue if response.has_tool_calls: @@ -390,7 +455,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: spec.session_key or "default", ) - clean = hook.finalize_content(context, response.content) + clean = center.finalize_content(response.content, session) if response.finish_reason != "error" and is_blank_text(clean): empty_content_retries += 1 if empty_content_retries < _MAX_EMPTY_RETRIES: @@ -401,9 +466,9 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: empty_content_retries, _MAX_EMPTY_RETRIES, ) - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) - await hook.after_iteration(context) + if center.wants_streaming(session): + await center.emit(OnStreamEnd(resuming=False, iteration=iteration), session) + await center.emit(_make_after_iteration(context, iteration), session) continue logger.warning( "Empty response on turn {} for {} after {} retries; attempting finalization", @@ -411,8 +476,8 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: spec.session_key or "default", empty_content_retries, ) - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) + if center.wants_streaming(session): + await center.emit(OnStreamEnd(resuming=False, iteration=iteration), session) response = await self._request_finalization_retry(spec, messages_for_model) retry_usage = self._usage_dict(response.usage) self._accumulate_usage(usage, retry_usage) @@ -420,7 +485,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: context.response = response context.usage = 
dict(raw_usage) context.tool_calls = list(response.tool_calls) - clean = hook.finalize_content(context, response.content) + clean = center.finalize_content(response.content, session) if response.finish_reason == "length" and not is_blank_text(clean): length_recovery_count += 1 @@ -432,15 +497,15 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: length_recovery_count, _MAX_LENGTH_RECOVERIES, ) - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=True) + if center.wants_streaming(session): + await center.emit(OnStreamEnd(resuming=True, iteration=iteration), session) messages.append(build_assistant_message( clean, reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, )) messages.append(build_length_recovery_message()) - await hook.after_iteration(context) + await center.emit(_make_after_iteration(context, iteration), session) continue assistant_message: dict[str, Any] | None = None @@ -462,11 +527,11 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: if should_continue: had_injections = True - if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=should_continue) + if center.wants_streaming(session): + await center.emit(OnStreamEnd(resuming=should_continue, iteration=iteration), session) if should_continue: - await hook.after_iteration(context) + await center.emit(_make_after_iteration(context, iteration), session) continue if response.finish_reason == "error": @@ -477,7 +542,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: context.final_content = final_content context.error = error context.stop_reason = stop_reason - await hook.after_iteration(context) + await center.emit(_make_after_iteration(context, iteration), session) should_continue, injection_cycles = await self._try_drain_injections( spec, messages, None, injection_cycles, phase="after LLM error", @@ -494,7 +559,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: context.final_content = 
final_content context.error = error context.stop_reason = stop_reason - await hook.after_iteration(context) + await center.emit(_make_after_iteration(context, iteration), session) should_continue, injection_cycles = await self._try_drain_injections( spec, messages, None, injection_cycles, phase="after empty response", @@ -523,7 +588,7 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: final_content = clean context.final_content = final_content context.stop_reason = stop_reason - await hook.after_iteration(context) + await center.emit(_make_after_iteration(context, iteration), session) break else: stop_reason = "max_iterations" @@ -549,6 +614,10 @@ async def run(self, spec: AgentRunSpec) -> AgentRunResult: ) if drained_after_max_iterations: had_injections = True + if context is not None: + context.final_content = final_content + context.stop_reason = stop_reason + await center.emit(_make_after_iteration(context, iteration), session) return AgentRunResult( final_content=final_content, @@ -587,7 +656,8 @@ async def _request_model( self, spec: AgentRunSpec, messages: list[dict[str, Any]], - hook: AgentHook, + center: HookCenter, + session: HookSession, context: AgentHookContext, ): timeout_s: float | None = spec.llm_timeout_s @@ -608,7 +678,7 @@ async def _request_model( messages, tools=spec.tools.get_definitions(), ) - wants_streaming = hook.wants_streaming() + wants_streaming = center.wants_streaming(session) wants_progress_streaming = ( not wants_streaming and spec.progress_callback is not None @@ -619,7 +689,7 @@ async def _request_model( async def _stream(delta: str) -> None: if delta: context.streamed_content = True - await hook.on_stream(context, delta) + await center.emit(OnStream(delta=delta, iteration=context.iteration), session) coro = self.provider.chat_stream_with_retry( **kwargs, diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index e64dc8f97c3..ee0106b454c 100644 --- a/nanobot/agent/subagent.py +++ 
b/nanobot/agent/subagent.py @@ -21,6 +21,8 @@ from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.config.schema import AgentDefaults, ExecToolConfig, WebToolsConfig +from nanobot.hooks.adapters import adapt_agent_hook +from nanobot.hooks.center import HookCenter from nanobot.providers.base import LLMProvider from nanobot.utils.prompt_templates import render_template @@ -82,6 +84,7 @@ def __init__( restrict_to_workspace: bool = False, disabled_skills: list[str] | None = None, max_iterations: int | None = None, + hook_center: HookCenter | None = None, ): self.provider = provider self.workspace = workspace @@ -98,6 +101,7 @@ def __init__( else AgentDefaults().max_tool_iterations ) self.runner = AgentRunner(provider) + self._hook_center = hook_center self._running_tasks: dict[str, asyncio.Task[None]] = {} self._task_statuses: dict[str, SubagentStatus] = {} self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...} @@ -204,13 +208,18 @@ async def _on_checkpoint(payload: dict) -> None: {"role": "user", "content": task}, ] + center = self._hook_center or HookCenter() + hook_session = center.create_session() + subagent_hook = _SubagentHook(task_id, status) + adapt_agent_hook(subagent_hook, hook_session, center) result = await self.runner.run(AgentRunSpec( initial_messages=messages, tools=tools, model=self.model, max_iterations=self.max_iterations, max_tool_result_chars=self.max_tool_result_chars, - hook=_SubagentHook(task_id, status), + center=center, + session=hook_session, max_iterations_message="Task completed but no final response was generated.", error_message=None, fail_on_tool_error=True, diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 903555b4784..c3b6d13f6d4 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -540,6 +540,7 @@ def serve( consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio, 
max_messages=runtime_config.agents.defaults.max_messages, tools_config=runtime_config.tools, + hooks_config=runtime_config.hooks, ) model_name = runtime_config.agents.defaults.model @@ -656,6 +657,7 @@ def _run_gateway( tools_config=config.tools, provider_snapshot_loader=load_provider_snapshot, provider_signature=provider_snapshot.signature, + hooks_config=config.hooks, ) from nanobot.agent.loop import UNIFIED_SESSION_KEY @@ -1047,6 +1049,7 @@ def agent( consolidation_ratio=config.agents.defaults.consolidation_ratio, max_messages=config.agents.defaults.max_messages, tools_config=config.tools, + hooks_config=config.hooks, ) restart_notice = consume_restart_notice_from_env() if restart_notice and should_show_cli_restart_notice(restart_notice, session_id): diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index a6c9d10c4dc..a48cdcb55a2 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -32,6 +32,17 @@ class ChannelsConfig(Base): transcription_language: str | None = Field(default=None, pattern=r"^[a-z]{2,3}$") # Optional ISO-639-1 hint for audio transcription +class HooksConfig(Base): + """Configuration for hook plugins. + + Default-deny: when ``enabled_plugins`` is ``None`` (the default), no + external hook plugins are loaded. Set to an explicit list of entry-point + names to opt in to loading discovered plugins. 
+ """ + + enabled_plugins: list[str] | None = None + + class DreamConfig(Base): """Dream memory consolidation configuration.""" @@ -250,6 +261,7 @@ class Config(BaseSettings): agents: AgentsConfig = Field(default_factory=AgentsConfig) channels: ChannelsConfig = Field(default_factory=ChannelsConfig) + hooks: HooksConfig = Field(default_factory=HooksConfig) providers: ProvidersConfig = Field(default_factory=ProvidersConfig) api: ApiConfig = Field(default_factory=ApiConfig) gateway: GatewayConfig = Field(default_factory=GatewayConfig) diff --git a/nanobot/hooks/__init__.py b/nanobot/hooks/__init__.py new file mode 100644 index 00000000000..6b82a93fc8c --- /dev/null +++ b/nanobot/hooks/__init__.py @@ -0,0 +1,41 @@ +"""HookCenter — typed-event hook system for nanobot. + +Provides: +- Typed event dataclasses for agent lifecycle hooks +- Handler Protocol and return types (Modified, Deny) +- HookCenter registry and dispatch engine +- Entry-point plugin discovery +- AgentHook compatibility adapter +""" + +from nanobot.hooks.adapters import adapt_agent_hook, adapt_agent_hook_list +from nanobot.hooks.center import HookCenter, HookSession +from nanobot.hooks.discovery import discover_hook_plugins, register_discovered +from nanobot.hooks.event_types import ( + AfterIteration, + BeforeExecuteTools, + BeforeIteration, + FinalizeContent, + OnStream, + OnStreamEnd, +) +from nanobot.hooks.protocols import Deny, HookHandler, HookResult, Modified + +__all__ = [ + "AfterIteration", + "BeforeExecuteTools", + "BeforeIteration", + "Deny", + "FinalizeContent", + "HookCenter", + "HookHandler", + "HookResult", + "HookSession", + "Modified", + "OnStream", + "OnStreamEnd", + "adapt_agent_hook", + "adapt_agent_hook_list", + "discover_hook_plugins", + "register_discovered", +] diff --git a/nanobot/hooks/adapters.py b/nanobot/hooks/adapters.py new file mode 100644 index 00000000000..68563b28f4a --- /dev/null +++ b/nanobot/hooks/adapters.py @@ -0,0 +1,191 @@ +"""AgentHook to HookCenter 
def _is_overridden(agent_hook: AgentHook, method_name: str) -> bool:
    """Return whether *agent_hook* replaces ``AgentHook.<method_name>``.

    The attribute found on the instance is unwrapped through ``__func__``
    (when present) and compared by identity with the function stored in
    ``AgentHook.__dict__``.  ``False`` is returned when either side does
    not define the attribute.
    """
    inst_method = getattr(agent_hook, method_name, None)
    if inst_method is None:
        return False
    base_method = AgentHook.__dict__.get(method_name)
    if base_method is None:
        return False
    # Bound methods expose the underlying function via ``__func__``.
    inst_func = getattr(inst_method, "__func__", inst_method)
    return inst_func is not base_method


def adapt_agent_hook(
    agent_hook: AgentHook,
    session: "HookSession",
    center: "HookCenter",
    *,
    reraise: bool | None = None,
) -> None:
    """Wrap a legacy AgentHook instance as typed-event handlers on *session*.

    Each non-default (overridden) method is converted to a handler
    registered via ``center.register_internal``.  The *reraise* flag
    controls error propagation: when ``None`` (the default), the
    adapter reads ``agent_hook._reraise`` and falls back to ``False``.

    Context resolution: every wrapper hands the hook the context that is
    currently published on ``session.context`` when there is one.  Only
    when the session carries no context does a wrapper fall back to the
    context cached by an earlier wrapper of the same iteration, and
    finally to a freshly built ``AgentHookContext``.  Preferring the
    session context keeps hooks that do not override ``before_iteration``
    from being handed the context of a previous iteration.
    """
    if reraise is None:
        reraise = getattr(agent_hook, "_reraise", False)

    # Last context handed to the hook; shared by all wrappers below.
    ctx_cell: dict[str, AgentHookContext | None] = {"ctx": None}

    def _event_ctx(iteration: int) -> AgentHookContext:
        """Return the context for *iteration* and remember it in ``ctx_cell``."""
        ctx = getattr(session, "context", None)
        if ctx is None:
            cached = ctx_cell["ctx"]
            # A cached context from another iteration is stale; rebuild.
            if cached is not None and getattr(cached, "iteration", iteration) == iteration:
                ctx = cached
            else:
                ctx = AgentHookContext(iteration=iteration, messages=[])
        ctx_cell["ctx"] = ctx
        return ctx

    # ── before_iteration ──────────────────────────────────────────
    if _is_overridden(agent_hook, "before_iteration"):

        async def _bi_wrapper(event: BeforeIteration) -> HookResult:
            ctx = getattr(session, "context", None)
            if ctx is None:
                # First event of an iteration: always start from a fresh context.
                ctx = AgentHookContext(iteration=event.iteration, messages=event.messages)
            ctx_cell["ctx"] = ctx
            await agent_hook.before_iteration(ctx)
            return None

        center.register_internal(
            session, BeforeIteration, _bi_wrapper, reraise=reraise, mode="observe"
        )

    # ── on_stream ─────────────────────────────────────────────────
    if _is_overridden(agent_hook, "on_stream"):

        async def _os_wrapper(event: OnStream) -> HookResult:
            await agent_hook.on_stream(_event_ctx(event.iteration), event.delta)
            return None

        # stream=False: streaming opt-in is decided by wants_streaming() below.
        center.register_internal(
            session, OnStream, _os_wrapper, reraise=reraise, mode="observe", stream=False
        )

    # ── on_stream_end ─────────────────────────────────────────────
    if _is_overridden(agent_hook, "on_stream_end"):

        async def _ose_wrapper(event: OnStreamEnd) -> HookResult:
            await agent_hook.on_stream_end(_event_ctx(event.iteration), resuming=event.resuming)
            return None

        center.register_internal(
            session, OnStreamEnd, _ose_wrapper, reraise=reraise, mode="observe", stream=False
        )

    # ── before_execute_tools ──────────────────────────────────────
    if _is_overridden(agent_hook, "before_execute_tools"):

        async def _bet_wrapper(event: BeforeExecuteTools) -> HookResult:
            ctx = _event_ctx(event.iteration)
            ctx.tool_calls = list(event.tool_calls)
            ctx.response = event.response
            await agent_hook.before_execute_tools(ctx)
            return None

        center.register_internal(
            session, BeforeExecuteTools, _bet_wrapper, reraise=reraise, mode="observe"
        )

    # ── after_iteration ───────────────────────────────────────────
    if _is_overridden(agent_hook, "after_iteration"):

        async def _ai_wrapper(event: AfterIteration) -> HookResult:
            ctx = _event_ctx(event.iteration)
            # Copy the event payload onto the legacy context object.
            ctx.final_content = event.final_content
            ctx.stop_reason = event.stop_reason
            ctx.usage = dict(event.usage)
            ctx.tool_calls = list(event.tool_calls)
            ctx.tool_events = list(event.tool_events)
            ctx.tool_results = list(event.tool_results)
            ctx.error = event.error
            await agent_hook.after_iteration(ctx)
            return None

        center.register_internal(
            session, AfterIteration, _ai_wrapper, reraise=reraise, mode="observe"
        )

    # ── finalize_content ──────────────────────────────────────────
    if _is_overridden(agent_hook, "finalize_content"):

        def _fc_wrapper(content: str | None) -> str | None:
            # No event (and so no iteration) is available here: use the
            # session context, then the last cached one, then a blank one.
            ctx = getattr(session, "context", None)
            if ctx is None:
                ctx = ctx_cell["ctx"]
            if ctx is None:
                ctx = AgentHookContext(iteration=0, messages=[])
            return agent_hook.finalize_content(ctx, content)

        center.register_internal(
            session, FinalizeContent, _fc_wrapper, reraise=reraise, mode="transform"
        )

    # ── wants_streaming ───────────────────────────────────────────
    if agent_hook.wants_streaming():

        # The sentinel is never dispatched; its presence in the set is
        # what makes the session report that streaming is wanted.
        async def _ws_sentinel(_event: Any) -> None:
            return None

        session.wants_streaming_handlers.add(_ws_sentinel)
None: + """Adapt a list of hooks, flattening CompositeHook instances. + + CompositeHook instances are recursively expanded into their + ``_hooks`` children so the adapter wires each leaf directly + rather than producing a double-layer fan-out. + """ + + def _flatten(hook_list: list[AgentHook]): + for h in hook_list: + if isinstance(h, CompositeHook): + yield from _flatten(h._hooks) + else: + yield h + + for leaf in _flatten(hooks): + adapt_agent_hook(leaf, session, center) diff --git a/nanobot/hooks/center.py b/nanobot/hooks/center.py new file mode 100644 index 00000000000..3e8042a967f --- /dev/null +++ b/nanobot/hooks/center.py @@ -0,0 +1,227 @@ +"""HookCenter — typed-event registry and dispatch engine. + +Guards, transforms, and observes are dispatched in strict order. +Internal handlers (session) run before external handlers (global). +""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass, field +from typing import Any + +from loguru import logger + +from nanobot.hooks.event_types import FinalizeContent as FinalizeContentEvent +from nanobot.hooks.event_types import OnStream, OnStreamEnd +from nanobot.hooks.protocols import Deny, HookHandler, HookResult, Modified + +_STREAMING_EVENT_TYPES = (OnStream, OnStreamEnd) +_VALID_MODES = {"guard", "transform", "observe"} + + +@dataclass(slots=True) +class HookSession: + internal_handlers: dict[type, dict[str, list[tuple[HookHandler, bool]]]] = field( + default_factory=dict + ) + wants_streaming_handlers: set[HookHandler] = field(default_factory=set) + finalize_handlers: list[tuple[HookHandler, bool]] = field(default_factory=list) + context: Any = field(default=None, init=False) + """Placeholder for runner-provided AgentHookContext. + + Set by AgentRunner.run() before each iteration. Adapter wrappers read + this reference to share the runner's mutable context object with + legacy AgentHook subclasses. 
class HookCenter:
    """Registry and dispatch engine for typed hook events.

    External handlers are stored on the center itself (``register``);
    internal handlers are stored on a per-run ``HookSession``
    (``register_internal``).  On dispatch, internal handlers always run
    before external ones, and modes run in the order guard, transform,
    observe.
    """

    __slots__ = ("_external_handlers", "_streaming_plugins")

    def __init__(self) -> None:
        # event type -> mode -> handlers, in registration order
        self._external_handlers: dict[type, dict[str, list[HookHandler]]] = {}
        # external plugins that requested streaming independently of any session
        self._streaming_plugins: set[HookHandler] = set()

    # ------------------------------------------------------------------
    # registry
    # ------------------------------------------------------------------

    def register(self, event_type: type, handler: HookHandler, mode: str) -> None:
        """Register an external *handler* for *event_type* under *mode*.

        Registering the same handler twice for the same event type and
        mode is a no-op.

        Raises:
            ValueError: if *mode* is not one of the valid modes.
        """
        if mode not in _VALID_MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {_VALID_MODES}")
        if event_type not in self._external_handlers:
            self._external_handlers[event_type] = {"guard": [], "transform": [], "observe": []}
        group = self._external_handlers[event_type][mode]
        if handler not in group:
            group.append(handler)

    def register_internal(
        self,
        session: HookSession,
        event_type: type,
        handler: HookHandler,
        *,
        reraise: bool = False,
        mode: str = "observe",
        stream: bool = True,
    ) -> None:
        """Register an internal *handler* on *session*.

        Args:
            session: session that owns the registration.
            event_type: event class the handler subscribes to.
            handler: callable invoked with the event.
            reraise: when true, exceptions raised by the handler propagate
                out of dispatch instead of being logged and swallowed.
            mode: ``"guard"``, ``"transform"`` or ``"observe"``.
            stream: when true and *event_type* is a streaming event, the
                handler also marks the session as wanting streaming.

        Raises:
            ValueError: if *mode* is not one of the valid modes.
        """
        if mode not in _VALID_MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {_VALID_MODES}")
        modes = session.internal_handlers.setdefault(
            event_type, {"guard": [], "transform": [], "observe": []}
        )
        group = modes[mode]
        item = (handler, reraise)
        if item not in group:
            group.append(item)
        if stream and event_type in _STREAMING_EVENT_TYPES:
            session.wants_streaming_handlers.add(handler)
        # Deduplicate here as well: a repeated registration must not make
        # the finalize pipeline run the same handler twice.
        if event_type is FinalizeContentEvent and item not in session.finalize_handlers:
            session.finalize_handlers.append(item)

    def create_session(self) -> HookSession:
        """Return a new, empty ``HookSession``."""
        return HookSession()

    # ------------------------------------------------------------------
    # dispatch
    # ------------------------------------------------------------------

    async def emit(self, event: Any, session: HookSession) -> HookResult:
        """Dispatch *event* to the handlers registered for ``type(event)``.

        Guards run first; the first ``Deny`` returned by a guard stops
        dispatch and is returned to the caller.  Transforms run next; a
        ``Modified`` result is applied to the event in place so later
        handlers see the updated fields.  Observers run last and their
        results are ignored.

        Returns:
            The ``Deny`` produced by a guard, otherwise ``None``.
        """
        event_type = type(event)

        internal = session.internal_handlers.get(event_type, {})
        external = self._external_handlers.get(event_type, {})

        # guards: internal first, then external
        for handler, reraise in internal.get("guard", []):
            result: HookResult = await self._invoke_handler(handler, event, reraise)
            if isinstance(result, Deny):
                return result
        for handler in external.get("guard", []):
            result = await self._invoke_handler(handler, event, reraise=False)
            if isinstance(result, Deny):
                return result

        # transforms: internal first, then external
        for handler, reraise in internal.get("transform", []):
            result = await self._invoke_handler(handler, event, reraise)
            if isinstance(result, Modified):
                event = self._apply_modified(event, result)
        for handler in external.get("transform", []):
            result = await self._invoke_handler(handler, event, reraise=False)
            if isinstance(result, Modified):
                event = self._apply_modified(event, result)

        # observes: internal first, then external
        for handler, reraise in internal.get("observe", []):
            await self._invoke_handler(handler, event, reraise)
        for handler in external.get("observe", []):
            await self._invoke_handler(handler, event, reraise=False)

        return None

    def wants_streaming(self, session: HookSession) -> bool:
        """Return whether any internal or external handler wants streaming."""
        if session.wants_streaming_handlers:
            return True
        if self._streaming_plugins:
            return True
        for et in _STREAMING_EVENT_TYPES:
            if self._external_handlers.get(et):
                return True
        return False

    def finalize_content(self, content: str | None, session: HookSession) -> str | None:
        """Run the synchronous finalize pipeline over *content*.

        Session handlers run first, then external handlers registered for
        ``FinalizeContent`` in ``"transform"`` mode.  A ``Deny`` result
        stops the pipeline and returns the content produced so far.
        """
        for handler, reraise in session.finalize_handlers:
            result, content = self._call_finalize_handler(handler, content, reraise)
            if isinstance(result, Deny):
                return content

        for handler in self._external_handlers.get(FinalizeContentEvent, {}).get("transform", []):
            result, content = self._call_finalize_handler(handler, content, reraise=False)
            if isinstance(result, Deny):
                return content

        return content

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _handler_name(handler: Any) -> str:
        """Return a log-friendly name for *handler*.

        Functions and methods report their qualified name; other callables
        fall back to their type name.
        """
        return getattr(handler, "__qualname__", None) or type(handler).__name__

    @staticmethod
    async def _invoke_handler(
        handler: HookHandler, event: Any, reraise: bool
    ) -> HookResult:
        """Call *handler* with *event*, awaiting the result when needed.

        With ``reraise=False`` any exception is logged and ``None`` is
        returned, so a failing handler cannot break dispatch.
        """

        async def _call() -> HookResult:
            result = handler(event)
            # Handlers may be plain callables or coroutine functions.
            if inspect.isawaitable(result):
                result = await result
            return result

        if reraise:
            return await _call()
        try:
            return await _call()
        except Exception:
            logger.exception(
                "HookCenter handler {} error in event {}",
                HookCenter._handler_name(handler),
                type(event).__name__,
            )
            return None

    @staticmethod
    def _apply_modified(event: Any, modified: Modified) -> Any:
        """Apply ``modified.data`` (a field-name → value dict) to *event*.

        Non-dict data is discarded with a warning; keys that are not
        attributes of the event are skipped with a warning.  The event is
        mutated in place and returned.
        """
        data = modified.data
        if not isinstance(data, dict):
            logger.warning(
                "Transform handler returned non-dict Modified.data ({}) — "
                "discarding, downstream handlers receive original event",
                type(data).__name__,
            )
            return event
        for key, value in data.items():
            if hasattr(event, key):
                setattr(event, key, value)
            else:
                logger.warning(
                    "Modified.data key {!r} not found on event type {} — typo?",
                    key, type(event).__name__,
                )
        return event

    @staticmethod
    def _call_finalize_handler(
        handler: Any, content: str | None, reraise: bool
    ) -> tuple[HookResult, str | None]:
        """Run one finalize handler and return ``(result, new_content)``.

        ``Modified`` replaces the content with ``result.data``; ``Deny``
        and ``None`` keep the content unchanged; any other return value is
        taken as the new content.  Awaitable results are rejected because
        the finalize pipeline is synchronous.
        """
        try:
            result = handler(content)
            if inspect.isawaitable(result):
                # Close an un-awaited coroutine so it does not emit a
                # "coroutine was never awaited" RuntimeWarning later.
                close = getattr(result, "close", None)
                if callable(close):
                    close()
                logger.warning(
                    "FinalizeContent handler {} returned coroutine — "
                    "finalize_content handlers must be synchronous; discarding result",
                    HookCenter._handler_name(handler),
                )
                return None, content
        except Exception:
            if reraise:
                raise
            logger.exception(
                "HookCenter finalize_content error in {}", HookCenter._handler_name(handler)
            )
            return None, content
        if isinstance(result, Modified):
            return result, result.data
        if isinstance(result, Deny):
            return result, content
        if result is None:
            return None, content
        return None, result

    # ------------------------------------------------------------------
    # lifecycle
    # ------------------------------------------------------------------

    def reset(self) -> None:
        """Drop every external registration, including streaming opt-ins."""
        self._external_handlers.clear()
        # Without this, wants_streaming() keeps answering True for plugins
        # that are no longer registered.
        self._streaming_plugins.clear()

    def discover(self, config: Any = None) -> None:
        """Discover entry-point plugins and register them on this center."""
        # Imported lazily: the discovery module refers back to HookCenter.
        from nanobot.hooks.discovery import register_discovered

        register_discovered(self, config)
+ """ + from importlib.metadata import entry_points + + plugins: dict[str, HookHandler] = {} + for ep in entry_points(group="nanobot.hooks"): + if enabled is None or ep.name not in enabled: + logger.info("Hook plugin '{}' not in enabled_plugins, skipping", ep.name) + continue + try: + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(ep.load) + handler = future.result(timeout=_PLUGIN_LOAD_TIMEOUT) + plugins[ep.name] = handler + except TimeoutError: + logger.warning("Hook plugin '{}' load timed out after {}s, skipping", ep.name, _PLUGIN_LOAD_TIMEOUT) + except Exception: + logger.warning("Failed to load hook plugin '{}'", ep.name) + return plugins + + +def register_discovered(center: "HookCenter", config: Any = None) -> None: + """Discover external hook plugins and register them into *center*. + + Filters by ``config.hooks.enabled_plugins`` allowlist when present. + The allowlist is enforced *before* ``ep.load()`` so blocked plugins + never execute their module-level code. Each plugin must expose + ``hook_events: list[tuple[type, str]]``. 
+ """ + enabled = getattr(getattr(config, "hooks", None), "enabled_plugins", None) + + try: + discovered = discover_hook_plugins(enabled=enabled) + except Exception: + logger.warning("entry_points discovery failed, continuing with core hook handlers only") + return + + for name, handler in discovered.items(): + hook_events = getattr(handler, "hook_events", []) + if not hook_events: + logger.warning("Hook plugin '{}' has no hook_events, skipping", name) + continue + + for event_type, mode in hook_events: + center.register(event_type, handler, mode) + + if getattr(handler, "hook_streaming", False): + center._streaming_plugins.add(handler) + logger.info("Hook plugin '{}' enabled streaming", name) + + logger.debug("Registered hook plugin '{}' with {} events", name, len(hook_events)) diff --git a/nanobot/hooks/event_types.py b/nanobot/hooks/event_types.py new file mode 100644 index 00000000000..2acfa152f91 --- /dev/null +++ b/nanobot/hooks/event_types.py @@ -0,0 +1,62 @@ +"""Typed event dataclasses for HookCenter lifecycle events. + +Each hook point in the agent lifecycle is represented by a typed dataclass. +Handlers subscribe by event type; HookCenter dispatches by ``type(event)``. 
@dataclass(slots=True)
class BeforeIteration:
    """Event type for the ``before_iteration`` hook point.

    The AgentHook adapter forwards it to ``AgentHook.before_iteration``.
    """

    iteration: int  # iteration index supplied by the emitter
    messages: list[dict[str, Any]]  # message dicts for this iteration (schema not defined here)


@dataclass(slots=True)
class OnStream:
    """Event type for one streamed content delta.

    The AgentHook adapter forwards ``delta`` to ``AgentHook.on_stream``.
    """

    delta: str  # streamed text fragment
    iteration: int  # iteration index supplied by the emitter


@dataclass(slots=True)
class OnStreamEnd:
    """Event type signalling the end of a stream.

    The AgentHook adapter forwards ``resuming`` to
    ``AgentHook.on_stream_end``.
    """

    resuming: bool  # passed through as the ``resuming`` keyword of on_stream_end
    iteration: int  # iteration index supplied by the emitter


@dataclass(slots=True)
class BeforeExecuteTools:
    """Event type for the ``before_execute_tools`` hook point.

    The AgentHook adapter copies ``tool_calls`` and ``response`` onto the
    hook context before calling ``AgentHook.before_execute_tools``.
    """

    iteration: int  # iteration index supplied by the emitter
    tool_calls: list["ToolCallRequest"]  # tool calls about to be executed
    response: "LLMResponse | None" = None  # optional provider response; defaults to None


@dataclass(slots=True)
class AfterIteration:
    """Event type for the ``after_iteration`` hook point.

    Every field except ``iteration`` has a default.  The AgentHook adapter
    copies all of them onto the hook context before calling
    ``AgentHook.after_iteration``.
    """

    iteration: int  # iteration index supplied by the emitter
    final_content: str | None = None
    stop_reason: str | None = None
    usage: dict[str, int] = field(default_factory=dict)
    tool_calls: list["ToolCallRequest"] = field(default_factory=list)
    tool_events: list[dict[str, str]] = field(default_factory=list)
    tool_results: list[Any] = field(default_factory=list)
    error: str | None = None


@dataclass(slots=True)
class FinalizeContent:
    """Registration-only marker for finalize_content pipeline handlers.

    Not dispatched through emit(). HookCenter.finalize_content() collects
    handlers registered under this type and runs them as a sync pipeline.
    This dataclass carries no fields — it exists purely as a type key for
    handler registration.
    """
    pass
@dataclass(slots=True)
class Modified:
    """Returned by transform-mode handlers to indicate data was modified.

    The ``data`` field carries the transformed value that replaces the
    original event data in the pipeline.

    For events dispatched through ``HookCenter.emit``, ``data`` must be a
    dict mapping event field names to their new values; any other type is
    discarded with a warning.  For ``finalize_content`` handlers, ``data``
    is used directly as the replacement content.
    """

    data: Any


@dataclass(slots=True)
class Deny:
    """Returned by guard-mode handlers to block an operation.

    ``reason`` is a human-readable string explaining the denial.
    The caller (emit site) decides how to act on it — soft-deny
    (inject reason into the conversation) or hard-deny (abort the
    agent loop entirely).

    When ``abort`` is ``False`` (default), the operation is denied
    but the agent loop continues — the reason is injected as a tool
    result so the LLM can adapt.  When ``abort`` is ``True``, the
    agent loop terminates immediately and ``reason`` becomes the
    final content returned to the caller.
    """

    reason: str
    abort: bool = False


# Everything a handler may return; ``None`` means "no action".
HookResult = Modified | Deny | None


class HookHandler(Protocol):
    """Protocol for hook handlers.

    Handlers accept an event dataclass and may return:
    - ``None``: observe only, no action needed
    - ``Modified(data)``: the event data was transformed
    - ``Deny(reason)``: the operation is denied

    The protocol is declared ``async``, but ``HookCenter`` only awaits a
    result when it is awaitable, so plain synchronous callables are
    accepted at runtime as well.
    """

    async def __call__(self, event: Any) -> HookResult: ...
diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index d2bff97d765..e723b819aa3 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -86,6 +86,7 @@ def from_config( session_ttl_minutes=defaults.session_ttl_minutes, consolidation_ratio=defaults.consolidation_ratio, tools_config=config.tools, + hooks_config=config.hooks, ) return cls(loop) diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 86ec18b8c87..b5e13873db6 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -10,9 +10,13 @@ import pytest -from nanobot.config.schema import AgentDefaults from nanobot.agent.tools.base import Tool from nanobot.agent.tools.registry import ToolRegistry +from nanobot.config.schema import AgentDefaults +from nanobot.hooks.adapters import adapt_agent_hook +from nanobot.hooks.center import HookCenter +from nanobot.hooks.event_types import BeforeExecuteTools, BeforeIteration +from nanobot.hooks.protocols import Deny from nanobot.providers.base import LLMResponse, ToolCallRequest _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars @@ -46,7 +50,7 @@ def _make_loop(tmp_path): @pytest.mark.asyncio async def test_runner_preserves_reasoning_fields_and_tool_results(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_second_call: list[dict] = [] @@ -104,7 +108,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_runner_calls_hooks_in_order(): from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() call_count = {"n": 0} @@ -150,13 +154,17 @@ def finalize_content(self, context: AgentHookContext, content: str | None) -> st return content.upper() if content else content runner = AgentRunner(provider) + center = HookCenter() + session = 
# ---------------------------------------------------------------------------
# Deny.abort — soft deny (continue) vs hard abort (break loop)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_runner_before_execute_tools_soft_deny_continues_loop():
    """A soft ``Deny`` on BeforeExecuteTools lets the run continue.

    The provider is asked a second time and its plain answer becomes the
    final content with stop reason ``"completed"``.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    provider = MagicMock()
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # First turn: request a tool call, which the guard denies.
            return LLMResponse(
                content="",
                tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
                finish_reason="tool_calls",
                usage={},
            )
        return LLMResponse(content="recovered", finish_reason="stop", usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    center = HookCenter()
    session = center.create_session()
    center.register(BeforeExecuteTools, lambda e: Deny(reason="not allowed"), mode="guard")

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        center=center,
        session=session,
    ))

    assert result.final_content == "recovered"
    assert result.stop_reason == "completed"
    assert call_count["n"] == 2


@pytest.mark.asyncio
async def test_runner_before_execute_tools_abort_deny_breaks_loop():
    """A ``Deny(abort=True)`` on BeforeExecuteTools ends the run.

    The denial reason becomes the final content and the stop reason is
    ``"aborted"``.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    provider = MagicMock()

    async def chat_with_retry(*, messages, **kwargs):
        return LLMResponse(
            content="",
            tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
            finish_reason="tool_calls",
            usage={},
        )

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    center = HookCenter()
    session = center.create_session()
    center.register(
        BeforeExecuteTools,
        lambda e: Deny(reason="security policy violation", abort=True),
        mode="guard",
    )

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        center=center,
        session=session,
    ))

    assert result.final_content == "security policy violation"
    assert result.stop_reason == "aborted"


@pytest.mark.asyncio
async def test_runner_before_iteration_abort_deny_breaks_loop():
    """A ``Deny(abort=True)`` on BeforeIteration ends the run before any LLM call."""
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(
        return_value=LLMResponse(content="should not reach", tool_calls=[], usage={}),
    )
    tools = MagicMock()
    tools.get_definitions.return_value = []

    center = HookCenter()
    session = center.create_session()
    center.register(
        BeforeIteration,
        lambda e: Deny(reason="iteration blocked", abort=True),
        mode="guard",
    )

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        center=center,
        session=session,
    ))

    assert result.final_content == "iteration blocked"
    assert result.stop_reason == "aborted"
    provider.chat_with_retry.assert_not_awaited()


@pytest.mark.asyncio
async def test_runner_before_iteration_soft_deny_continues_loop():
    """A soft ``Deny`` on BeforeIteration skips that iteration only.

    The guard denies its first invocation and allows the second, so the
    guard fires twice while the provider is called exactly once.
    """
    from nanobot.agent.runner import AgentRunner, AgentRunSpec

    provider = MagicMock()
    call_count = {"n": 0}
    deny_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            return LLMResponse(content="recovered", tool_calls=[], usage={})
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    async def conditional_guard(event):
        deny_count["n"] += 1
        if deny_count["n"] == 1:
            return Deny(reason="rate limited")
        return None

    center = HookCenter()
    session = center.create_session()
    center.register(
        BeforeIteration,
        conditional_guard,
        mode="guard",
    )

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        center=center,
        session=session,
    ))

    assert result.final_content == "recovered"
    assert result.stop_reason == "completed"
    assert call_count["n"] == 1
    assert deny_count["n"] == 2  # fired on both iterations
test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path): @pytest.mark.asyncio async def test_runner_replaces_empty_tool_result_with_marker(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_second_call: list[dict] = [] @@ -505,7 +689,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_runner_uses_raw_messages_when_context_governance_fails(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_messages: list[dict] = [] @@ -539,7 +723,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_runner_retries_empty_final_response_with_summary_prompt(): """Empty responses get 2 silent retries before finalization kicks in.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() calls: list[dict] = [] @@ -584,7 +768,7 @@ async def chat_with_retry(*, messages, tools=None, **kwargs): @pytest.mark.asyncio async def test_runner_uses_specific_message_after_empty_finalization_retry(): """After silent retries + finalization all return empty, stop_reason is empty_final_response.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE provider = MagicMock() @@ -616,7 +800,7 @@ async def test_runner_empty_response_does_not_break_tool_chain(): Sequence: tool_call → empty → tool_call → final text. The runner should recover via silent retry and complete normally. 
""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() call_count = 0 @@ -670,7 +854,7 @@ async def fake_tool(name, args, **kw): def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() tools = MagicMock() @@ -721,7 +905,7 @@ def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch @pytest.mark.asyncio async def test_runner_keeps_going_when_tool_result_persistence_fails(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_second_call: list[dict] = [] @@ -803,7 +987,7 @@ async def execute(self, **kwargs): @pytest.mark.asyncio async def test_runner_batches_read_only_tools_before_exclusive_work(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec tools = ToolRegistry() shared_events: list[str] = [] @@ -841,7 +1025,7 @@ async def test_runner_batches_read_only_tools_before_exclusive_work(): @pytest.mark.asyncio async def test_runner_does_not_batch_exclusive_read_only_tools(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec tools = ToolRegistry() shared_events: list[str] = [] @@ -883,7 +1067,7 @@ async def test_runner_does_not_batch_exclusive_read_only_tools(): @pytest.mark.asyncio async def test_runner_blocks_repeated_external_fetches(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_final_call: list[dict] = [] @@ -996,9 +1180,9 @@ async def test_llm_error_not_appended_to_session_messages(): """When LLM returns finish_reason='error', 
the error content must NOT be appended to the messages list (prevents polluting session history).""" from nanobot.agent.runner import ( - AgentRunSpec, - AgentRunner, _PERSISTED_MODEL_ERROR_PLACEHOLDER, + AgentRunner, + AgentRunSpec, ) provider = MagicMock() @@ -1110,7 +1294,7 @@ async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path): @pytest.mark.asyncio async def test_runner_tool_error_sets_final_content(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() @@ -1178,7 +1362,7 @@ async def fake_execute(self, **kwargs): async def test_runner_accumulates_usage_and_preserves_cached_tokens(): """Runner should accumulate prompt/completion tokens across iterations and preserve cached_tokens from provider responses.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() call_count = {"n": 0} @@ -1221,7 +1405,7 @@ async def chat_with_retry(*, messages, **kwargs): async def test_runner_passes_cached_tokens_to_hook_context(): """Hook context.usage should contain cached_tokens.""" from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_usage: list[dict] = [] @@ -1242,13 +1426,17 @@ async def chat_with_retry(**kwargs): tools.get_definitions.return_value = [] runner = AgentRunner(provider) + center = HookCenter() + session = center.create_session() + adapt_agent_hook(UsageHook(), session, center) await runner.run(AgentRunSpec( initial_messages=[], tools=tools, model="test-model", max_iterations=1, max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=UsageHook(), + center=center, + session=session, )) assert len(captured_usage) == 1 @@ -1264,7 +1452,7 @@ async def chat_with_retry(**kwargs): async def 
test_length_recovery_continues_from_truncated_output(): """When finish_reason is 'length', runner should insert a continuation prompt and retry, stitching partial outputs into the final result.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() call_count = {"n": 0} @@ -1304,7 +1492,7 @@ async def test_length_recovery_streaming_calls_on_stream_end_with_resuming(): """During length recovery with streaming, on_stream_end should be called with resuming=True so the hook knows the conversation is continuing.""" from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() call_count = {"n": 0} @@ -1331,13 +1519,17 @@ async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): tools.get_definitions.return_value = [] runner = AgentRunner(provider) + center = HookCenter() + session = center.create_session() + adapt_agent_hook(StreamHook(), session, center) await runner.run(AgentRunSpec( initial_messages=[{"role": "user", "content": "go"}], tools=tools, model="test-model", max_iterations=10, max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=StreamHook(), + center=center, + session=session, )) assert len(stream_end_calls) == 2 @@ -1348,7 +1540,7 @@ async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): @pytest.mark.asyncio async def test_length_recovery_gives_up_after_max_retries(): """After _MAX_LENGTH_RECOVERIES attempts the runner should stop retrying.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_LENGTH_RECOVERIES + from nanobot.agent.runner import _MAX_LENGTH_RECOVERIES, AgentRunner, AgentRunSpec provider = MagicMock() call_count = {"n": 0} @@ -1386,7 +1578,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def 
test_backfill_missing_tool_results_inserts_error(): """Orphaned tool_use (no matching tool_result) should get a synthetic error.""" - from nanobot.agent.runner import AgentRunner, _BACKFILL_CONTENT + from nanobot.agent.runner import _BACKFILL_CONTENT, AgentRunner messages = [ {"role": "user", "content": "hi"}, @@ -1467,7 +1659,7 @@ async def test_backfill_noop_when_complete(): @pytest.mark.asyncio async def test_runner_drops_orphan_tool_results_before_model_request(): - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() captured_messages: list[dict] = [] @@ -1592,7 +1784,7 @@ async def test_backfill_repairs_model_context_without_shifting_save_turn_boundar @pytest.mark.asyncio async def test_runner_backfill_only_mutates_model_context_not_returned_messages(): """Runner should repair orphaned tool calls for the model without rewriting result.messages.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _BACKFILL_CONTENT + from nanobot.agent.runner import _BACKFILL_CONTENT, AgentRunner, AgentRunSpec provider = MagicMock() captured_messages: list[dict] = [] @@ -1675,7 +1867,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_microcompact_replaces_old_tool_results(): """Tool results beyond _MICROCOMPACT_KEEP_RECENT should be summarized.""" - from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT + from nanobot.agent.runner import _MICROCOMPACT_KEEP_RECENT, AgentRunner total = _MICROCOMPACT_KEEP_RECENT + 5 long_content = "x" * 600 @@ -1703,7 +1895,7 @@ async def test_microcompact_replaces_old_tool_results(): @pytest.mark.asyncio async def test_microcompact_preserves_short_results(): """Short tool results (< _MICROCOMPACT_MIN_CHARS) should not be replaced.""" - from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT + from nanobot.agent.runner import _MICROCOMPACT_KEEP_RECENT, AgentRunner total = 
_MICROCOMPACT_KEEP_RECENT + 5 messages: list[dict] = [] @@ -1725,7 +1917,7 @@ async def test_microcompact_preserves_short_results(): @pytest.mark.asyncio async def test_microcompact_skips_non_compactable_tools(): """Non-compactable tools (e.g. 'message') should never be replaced.""" - from nanobot.agent.runner import AgentRunner, _MICROCOMPACT_KEEP_RECENT + from nanobot.agent.runner import _MICROCOMPACT_KEEP_RECENT, AgentRunner total = _MICROCOMPACT_KEEP_RECENT + 5 long_content = "y" * 1000 @@ -1749,7 +1941,7 @@ async def test_microcompact_skips_non_compactable_tools(): async def test_runner_tool_error_preserves_tool_results_in_messages(): """When a tool raises a fatal error, its results must still be appended to messages so the session never contains orphan tool_calls (#2943).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() @@ -1864,7 +2056,7 @@ def test_governance_fallback_still_repairs_orphans(): @pytest.mark.asyncio async def test_drain_injections_returns_empty_when_no_callback(): """No injection_callback → empty list.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() runner = AgentRunner(provider) @@ -1882,7 +2074,7 @@ async def test_drain_injections_returns_empty_when_no_callback(): @pytest.mark.asyncio async def test_drain_injections_extracts_content_from_inbound_messages(): """Should extract .content from InboundMessage objects.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -1913,7 +2105,7 @@ async def cb(): @pytest.mark.asyncio async def test_drain_injections_passes_limit_to_callback_when_supported(): """Limit-aware callbacks can preserve overflow in their own queue.""" - from nanobot.agent.runner import 
AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -1948,7 +2140,7 @@ async def cb(*, limit: int): @pytest.mark.asyncio async def test_drain_injections_skips_empty_content(): """Messages with blank content should be filtered out.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -1977,7 +2169,7 @@ async def cb(): @pytest.mark.asyncio async def test_drain_injections_handles_callback_exception(): """If the callback raises, return empty list (error is logged).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() runner = AgentRunner(provider) @@ -1999,7 +2191,7 @@ async def cb(): @pytest.mark.asyncio async def test_checkpoint1_injects_after_tool_execution(): """Follow-up messages are injected after tool execution, before next LLM call.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2052,8 +2244,8 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): """After final response, if injections exist, stream_end should get resuming=True.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2089,13 +2281,17 @@ async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): ) runner = AgentRunner(provider) + center = 
HookCenter() + session = center.create_session() + adapt_agent_hook(TrackingHook(), session, center) result = await runner.run(AgentRunSpec( initial_messages=[{"role": "user", "content": "hello"}], tools=tools, model="test-model", max_iterations=5, max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, - hook=TrackingHook(), + center=center, + session=session, injection_callback=inject_cb, )) @@ -2111,7 +2307,7 @@ async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): @pytest.mark.asyncio async def test_checkpoint2_preserves_final_response_in_history_before_followup(): """A follow-up injected after a final answer must still see that answer in history.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2226,7 +2422,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_runner_merges_multiple_injected_user_messages_without_losing_media(): """Multiple injected follow-ups should not create lossy consecutive user messages.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() call_count = {"n": 0} @@ -2289,7 +2485,7 @@ async def inject_cb(): @pytest.mark.asyncio async def test_injection_cycles_capped_at_max(): """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.agent.runner import _MAX_INJECTION_CYCLES, AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2330,7 +2526,7 @@ async def inject_cb(): @pytest.mark.asyncio async def test_no_injections_flag_is_false_by_default(): """had_injections should be False when no injection callback or no messages.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from 
nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() @@ -2410,9 +2606,9 @@ async def test_followup_routed_to_pending_queue(tmp_path): async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path): """Pending queue should leave overflow messages queued for later drains.""" from nanobot.agent.loop import AgentLoop + from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus - from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN bus = MessageBus() provider = MagicMock() @@ -2535,7 +2731,7 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path): @pytest.mark.asyncio async def test_drain_injections_on_fatal_tool_error(): """Pending injections should be drained even when a fatal tool error occurs.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2588,7 +2784,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_drain_injections_on_llm_error(): """Pending injections should be drained when the LLM returns an error finish_reason.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2643,7 +2839,7 @@ async def chat_with_retry(*, messages, **kwargs): @pytest.mark.asyncio async def test_drain_injections_on_empty_final_response(): """Pending injections should be drained when the runner exits due to empty response.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES + from nanobot.agent.runner import _MAX_EMPTY_RETRIES, AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2698,7 +2894,7 @@ async def 
test_drain_injections_on_max_iterations(): injections are appended to messages but not processed by the LLM. The key point is they are consumed from the queue to prevent re-publish. """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2750,7 +2946,7 @@ async def chat_with_retry(*, messages, **kwargs): async def test_drain_injections_set_flag_when_followup_arrives_after_last_iteration(): """Late follow-ups drained in max_iterations should still flip had_injections.""" from nanobot.agent.hook import AgentHook - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2789,6 +2985,9 @@ async def after_iteration(self, context) -> None: ) runner = AgentRunner(provider) + center = HookCenter() + session = center.create_session() + adapt_agent_hook(InjectOnLastAfterIterationHook(), session, center) result = await runner.run(AgentRunSpec( initial_messages=[{"role": "user", "content": "hello"}], tools=tools, @@ -2796,7 +2995,8 @@ async def after_iteration(self, context) -> None: max_iterations=2, max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, injection_callback=inject_cb, - hook=InjectOnLastAfterIterationHook(), + center=center, + session=session, )) assert result.stop_reason == "max_iterations" @@ -2812,7 +3012,7 @@ async def after_iteration(self, context) -> None: @pytest.mark.asyncio async def test_injection_cycle_cap_on_error_path(): """Injection cycles should be capped even when every iteration hits an LLM error.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.agent.runner import _MAX_INJECTION_CYCLES, AgentRunner, AgentRunSpec from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -2876,7 +3076,7 @@ def 
test_snip_history_preserves_user_message_after_truncation(monkeypatch): - _snip_history activates, keeping only recent assistant/tool pairs. - The injected user message is in the truncated prefix and gets lost. """ - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() tools = MagicMock() @@ -2940,7 +3140,7 @@ def test_snip_history_preserves_user_message_after_truncation(monkeypatch): def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch): """Edge case: if non_system has zero user messages, _snip_history should still return a valid sequence (not crash or produce system→assistant).""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock() tools = MagicMock() @@ -2996,7 +3196,7 @@ async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress(): internal retry diagnostics like "Model request failed, retry in 1s" to leak to end-user channels as normal progress updates. 
""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec captured: dict = {} diff --git a/tests/hooks/test_adapters.py b/tests/hooks/test_adapters.py new file mode 100644 index 00000000000..e52257dfed1 --- /dev/null +++ b/tests/hooks/test_adapters.py @@ -0,0 +1,591 @@ +"""Tests for AgentHook → HookCenter adapter (U4).""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook +from nanobot.hooks.adapters import adapt_agent_hook, adapt_agent_hook_list +from nanobot.hooks.center import HookCenter +from nanobot.hooks.event_types import ( + AfterIteration, + BeforeExecuteTools, + BeforeIteration, + OnStream, + OnStreamEnd, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _bi_ev(iteration=0, messages=None): + if messages is None: + messages = [] + return BeforeIteration(iteration=iteration, messages=messages) + + +def _os_ev(delta="x", iteration=0): + return OnStream(delta=delta, iteration=iteration) + + +def _ose_ev(resuming=True, iteration=0): + return OnStreamEnd(resuming=resuming, iteration=iteration) + + +def _bet_ev(iteration=0, tool_calls=None, response=None): + if tool_calls is None: + tool_calls = [] + return BeforeExecuteTools(iteration=iteration, tool_calls=tool_calls, response=response) + + +def _ai_ev(iteration=0, **kw): + defaults: dict[str, Any] = { + "final_content": "ok", + "stop_reason": "completed", + "usage": {}, + "tool_calls": [], + "tool_events": [], + "tool_results": [], + "error": None, + } + defaults.update(kw) + return AfterIteration(iteration=iteration, **defaults) + + +# --------------------------------------------------------------------------- +# Happy path: overridden methods are adapted and called +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_before_iteration_adapted_and_called(): + center = HookCenter() + session = center.create_session() + calls: list[int] = [] + + class H(AgentHook): + async def before_iteration(self, ctx: AgentHookContext) -> None: + calls.append(ctx.iteration) + + hook = H() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(iteration=3), session) + + assert calls == [3] + + +@pytest.mark.asyncio +async def test_on_stream_adapted_and_called(): + center = HookCenter() + session = center.create_session() + deltas: list[str] = [] + + class H(AgentHook): + async def on_stream(self, ctx: AgentHookContext, delta: str) -> None: + deltas.append(delta) + + hook = H() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(), session) + + await center.emit(_os_ev(delta="hello"), session) + await center.emit(_os_ev(delta=" world"), session) + + assert deltas == ["hello", " world"] + + +@pytest.mark.asyncio +async def test_on_stream_end_adapted_and_called(): + center = HookCenter() + session = center.create_session() + recv: list[bool] = [] + + class H(AgentHook): + async def on_stream_end(self, ctx: AgentHookContext, *, resuming: bool) -> None: + recv.append(resuming) + + hook = H() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(), session) # seed context + await center.emit(_ose_ev(resuming=True), session) + await center.emit(_ose_ev(resuming=False), session) + + assert recv == [True, False] + + +@pytest.mark.asyncio +async def test_before_execute_tools_adapted_and_called(): + center = HookCenter() + session = center.create_session() + saw_tool_calls: list[list[Any]] = [] + + class H(AgentHook): + async def before_execute_tools(self, ctx: AgentHookContext) -> None: + saw_tool_calls.append(list(ctx.tool_calls)) + + hook = H() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(), session) # seed 
context + fake_tc = [object()] + await center.emit(_bet_ev(tool_calls=fake_tc), session) + + assert len(saw_tool_calls) == 1 + assert saw_tool_calls[0] == fake_tc + + +@pytest.mark.asyncio +async def test_after_iteration_adapted_and_called(): + center = HookCenter() + session = center.create_session() + seen: list[AgentHookContext] = [] + + class H(AgentHook): + async def after_iteration(self, ctx: AgentHookContext) -> None: + seen.append(ctx) + + hook = H() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(), session) # seed + await center.emit( + _ai_ev( + iteration=5, + final_content="done", + stop_reason="completed", + error=None, + ), + session, + ) + + assert len(seen) == 1 + assert seen[0].iteration == 5 + assert seen[0].final_content == "done" + assert seen[0].stop_reason == "completed" + + +def test_finalize_content_adapted_and_called(): + center = HookCenter() + session = center.create_session() + + class H(AgentHook): + def finalize_content(self, ctx: AgentHookContext, content: str | None) -> str | None: + return (content or "") + "_adapted" + + hook = H() + adapt_agent_hook(hook, session, center) + + result = center.finalize_content("hello", session) + + assert result == "hello_adapted" + + +def test_finalize_content_pipeline_ordering(): + center = HookCenter() + session = center.create_session() + + class Upper(AgentHook): + def finalize_content(self, ctx, content): + return content.upper() if content else content + + class Suffix(AgentHook): + def finalize_content(self, ctx, content): + return (content + "!") if content else content + + adapt_agent_hook(Upper(), session, center) + adapt_agent_hook(Suffix(), session, center) + + result = center.finalize_content("hello", session) + assert result == "HELLO!" 
+ + +def test_wants_streaming_true_adapted(): + center = HookCenter() + session = center.create_session() + + class H(AgentHook): + def wants_streaming(self) -> bool: + return True + + adapt_agent_hook(H(), session, center) + + assert center.wants_streaming(session) is True + + +def test_wants_streaming_false_not_set(): + center = HookCenter() + session = center.create_session() + + adapt_agent_hook(AgentHook(), session, center) + + assert center.wants_streaming(session) is False + + +# --------------------------------------------------------------------------- +# Edge: only overridden methods are registered +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_only_overridden_methods_are_registered(): + center = HookCenter() + session = center.create_session() + calls: list[str] = [] + + class PartialHook(AgentHook): + async def before_iteration(self, ctx): + calls.append("bi") + + # on_stream, on_stream_end, before_execute_tools NOT overridden + # after_iteration NOT overridden + + hook = PartialHook() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(), session) + # These events have no registered handlers (just the base no-op) + await center.emit(_os_ev(), session) + await center.emit(_ose_ev(), session) + + assert calls == ["bi"] + + +# --------------------------------------------------------------------------- +# Edge: _reraise propagation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reraise_true_from_hook_attribute(): + center = HookCenter() + session = center.create_session() + + class H(AgentHook): + def __init__(self): + self._reraise = True + + async def before_iteration(self, ctx): + raise RuntimeError("reraise-me") + + hook = H() + adapt_agent_hook(hook, session, center) + + with pytest.raises(RuntimeError, match="reraise-me"): + await center.emit(_bi_ev(), session) + + +@pytest.mark.asyncio 
+async def test_reraise_explicit_overrides(): + center = HookCenter() + session = center.create_session() + + class H(AgentHook): + def __init__(self): + self._reraise = True + + async def before_iteration(self, ctx): + raise RuntimeError("bad") + + hook = H() + adapt_agent_hook(hook, session, center, reraise=False) + + await center.emit(_bi_ev(), session) # should not raise + + +@pytest.mark.asyncio +async def test_reraise_false_catches(): + center = HookCenter() + session = center.create_session() + + class H(AgentHook): + async def before_iteration(self, ctx): + raise RuntimeError("caught") + + adapt_agent_hook(H(), session, center) + + await center.emit(_bi_ev(), session) # no exception + + +# --------------------------------------------------------------------------- +# Edge: context without prior before_iteration (lazy init) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_on_stream_works_without_prior_context(): + center = HookCenter() + session = center.create_session() + deltas: list[str] = [] + + class H(AgentHook): + async def on_stream(self, ctx, delta): + deltas.append(delta) + + adapt_agent_hook(H(), session, center) + + await center.emit(_os_ev(delta="direct"), session) + + assert deltas == ["direct"] + + +# --------------------------------------------------------------------------- +# Integration: adapt_agent_hook_list +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_adapt_agent_hook_list_fan_out(): + center = HookCenter() + session = center.create_session() + calls: list[str] = [] + + class H1(AgentHook): + async def before_iteration(self, ctx): + calls.append("h1") + + class H2(AgentHook): + async def before_iteration(self, ctx): + calls.append("h2") + + adapt_agent_hook_list([H1(), H2()], session, center) + + await center.emit(_bi_ev(), session) + + assert calls == ["h1", "h2"] + + +@pytest.mark.asyncio 
+async def test_adapt_agent_hook_list_flattens_composite(): + center = HookCenter() + session = center.create_session() + calls: list[str] = [] + + class Inner1(AgentHook): + async def before_iteration(self, ctx): + calls.append("inner1") + + class Inner2(AgentHook): + async def before_iteration(self, ctx): + calls.append("inner2") + + composite = CompositeHook([Inner1(), CompositeHook([Inner2()])]) + + adapt_agent_hook_list([composite], session, center) + + await center.emit(_bi_ev(), session) + + assert calls == ["inner1", "inner2"] + + +@pytest.mark.asyncio +async def test_adapt_agent_hook_list_no_double_adapt(): + """CompositeHook itself is NOT adapted — only its leaves.""" + center = HookCenter() + session = center.create_session() + calls: list[str] = [] + + class Leaf(AgentHook): + async def before_iteration(self, ctx): + calls.append("leaf") + + composite = CompositeHook([Leaf()]) + adapt_agent_hook_list([composite], session, center) + + await center.emit(_bi_ev(), session) + + # If CompositeHook were adapted directly AND its leaves, + # we'd get duplicate calls. We expect exactly one. 
+ assert calls == ["leaf"] + + +# --------------------------------------------------------------------------- +# Integration: RecordingHook pattern from test_hook_composite.py +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_recording_hook_through_adapter(): + """RecordingHook (same pattern as test_hook_composite.py) works through adapter.""" + center = HookCenter() + session = center.create_session() + events: list[str] = [] + + class RecordingHook(AgentHook): + async def before_iteration(self, ctx): + events.append("before_iteration") + + async def on_stream(self, ctx, delta): + events.append(f"on_stream:{delta}") + + async def on_stream_end(self, ctx, *, resuming): + events.append(f"on_stream_end:{resuming}") + + async def before_execute_tools(self, ctx): + events.append("before_execute_tools") + + async def after_iteration(self, ctx): + events.append("after_iteration") + + hook = RecordingHook() + adapt_agent_hook(hook, session, center) + + await center.emit(_bi_ev(), session) + await center.emit(_os_ev(delta="hi"), session) + await center.emit(_ose_ev(resuming=True), session) + await center.emit(_bet_ev(), session) + await center.emit(_ai_ev(), session) + + assert events == [ + "before_iteration", + "on_stream:hi", + "on_stream_end:True", + "before_execute_tools", + "after_iteration", + ] + + +# --------------------------------------------------------------------------- +# Context accumulation across iteration lifecycle +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_context_accumulates_across_events(): + """Context built in BeforeIteration is visible to subsequent events.""" + center = HookCenter() + session = center.create_session() + seen_messages: list[list[dict]] = [] + + class H(AgentHook): + async def before_iteration(self, ctx): + pass # just creates context + + async def on_stream(self, ctx, delta): + 
seen_messages.append(list(ctx.messages)) + + adapt_agent_hook(H(), session, center) + + test_msgs = [{"role": "user", "content": "hello"}] + await center.emit(_bi_ev(messages=test_msgs), session) + await center.emit(_os_ev(delta="a"), session) + + assert seen_messages == [test_msgs] + + +@pytest.mark.asyncio +async def test_before_execute_tools_receives_accumulated_context(): + """before_execute_tools sees the context built from before_iteration + streaming.""" + center = HookCenter() + session = center.create_session() + captured_iter: list[int] = [] + captured_messages: list[list[dict]] = [] + + class H(AgentHook): + async def before_iteration(self, ctx): + pass # seeds the context cell + + async def before_execute_tools(self, ctx): + captured_iter.append(ctx.iteration) + captured_messages.append(list(ctx.messages)) + + adapt_agent_hook(H(), session, center) + + test_msgs = [{"role": "user", "content": "test"}] + await center.emit(_bi_ev(iteration=7, messages=test_msgs), session) + await center.emit(_bet_ev(iteration=7), session) + + assert captured_iter == [7] + assert captured_messages == [test_msgs] + + +# --------------------------------------------------------------------------- +# finalize_content with None content +# --------------------------------------------------------------------------- + + +def test_finalize_content_none_passthrough(): + center = HookCenter() + session = center.create_session() + + class H(AgentHook): + def finalize_content(self, ctx, content): + return content + + adapt_agent_hook(H(), session, center) + + result = center.finalize_content(None, session) + assert result is None + + +def test_finalize_content_error_caught_by_center(): + center = HookCenter() + session = center.create_session() + + class Bad(AgentHook): + def finalize_content(self, ctx, content): + raise RuntimeError("bad finalize") + + class Good(AgentHook): + def finalize_content(self, ctx, content): + return (content or "") + "_good" + + adapt_agent_hook(Bad(), 
session, center) + adapt_agent_hook(Good(), session, center) + + result = center.finalize_content("test", session) + assert result == "test_good" + + +# --------------------------------------------------------------------------- +# adapt_agent_hook_list mixed: some with overridden, some without +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_adapt_list_mixed_overrides(): + center = HookCenter() + session = center.create_session() + calls: list[str] = [] + + class Overriding(AgentHook): + async def before_iteration(self, ctx): + calls.append("overridden") + + adapt_agent_hook_list([Overriding(), AgentHook()], session, center) + + await center.emit(_bi_ev(), session) + + # Only the overriding hook should trigger + assert calls == ["overridden"] + + +# --------------------------------------------------------------------------- +# on_stream_end: resuming keyword is passed correctly +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_on_stream_end_resuming_keyword(): + center = HookCenter() + session = center.create_session() + resuming_values: list[bool] = [] + + class H(AgentHook): + async def on_stream_end(self, ctx, *, resuming): + resuming_values.append(resuming) + + adapt_agent_hook(H(), session, center) + + await center.emit(_bi_ev(), session) + await center.emit(_ose_ev(resuming=True), session) + await center.emit(_ose_ev(resuming=False), session) + await center.emit(_ose_ev(resuming=False), session) + + assert resuming_values == [True, False, False] diff --git a/tests/hooks/test_center.py b/tests/hooks/test_center.py new file mode 100644 index 00000000000..0d35bb85e2a --- /dev/null +++ b/tests/hooks/test_center.py @@ -0,0 +1,672 @@ +"""Tests for HookCenter registry and dispatch engine.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from nanobot.hooks.center import HookCenter 
from nanobot.hooks.event_types import (
    BeforeIteration,
    FinalizeContent,
    OnStream,
    OnStreamEnd,
)
from nanobot.hooks.protocols import Deny, HookResult, Modified

# ---------------------------------------------------------------------------
# Happy path: register_internal handler
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_register_internal_handler_is_called():
    """A handler added via ``register_internal`` receives the emitted event once."""
    hub = HookCenter()
    sess = hub.create_session()
    spy = Mock(return_value=None)
    hub.register_internal(sess, BeforeIteration, spy)
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    spy.assert_called_once_with(evt)


@pytest.mark.asyncio
async def test_register_internal_handler_sets_mode():
    """``register_internal`` accepts an explicit ``mode`` and the handler still fires."""
    hub = HookCenter()
    sess = hub.create_session()
    spy = Mock(return_value=None)

    hub.register_internal(sess, BeforeIteration, spy, mode="guard")
    evt = BeforeIteration(iteration=0, messages=[])
    await hub.emit(evt, sess)

    # NOTE(review): only the call is checked; nothing here proves the handler
    # was actually filed under "guard" rather than another mode.
    spy.assert_called_once_with(evt)


# ---------------------------------------------------------------------------
# Happy path: register external handler
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_register_external_handler_is_called():
    """A handler added via ``register`` is invoked for a session's emit."""
    hub = HookCenter()
    sess = hub.create_session()
    spy = Mock(return_value=None)
    hub.register(BeforeIteration, spy, mode="observe")
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    spy.assert_called_once_with(evt)


@pytest.mark.asyncio
async def test_guard_handler_returns_deny_stops_emit():
    """A guard returning ``Deny`` becomes the emit result and skips observers."""
    hub = HookCenter()
    sess = hub.create_session()
    gatekeeper = Mock(return_value=Deny(reason="blocked"))
    downstream = Mock()

    hub.register(BeforeIteration, gatekeeper, mode="guard")
    hub.register(BeforeIteration, downstream, mode="observe")
    evt = BeforeIteration(iteration=0, messages=[])

    outcome = await hub.emit(evt, sess)

    assert isinstance(outcome, Deny)
    assert outcome.reason == "blocked"
    gatekeeper.assert_called_once_with(evt)
    downstream.assert_not_called()


# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_transform_pipeline_second_sees_first_result():
    """The second transform observes the value produced by the first."""
    hub = HookCenter()
    sess = hub.create_session()
    observed: list[int] = []

    async def t1(ev: BeforeIteration) -> HookResult:
        observed.append(ev.iteration)
        # NOTE(review): the event is mutated in place *and* a Modified result
        # is returned, so this test cannot tell which of the two mechanisms
        # carries the change to the next transform.
        ev.iteration = 99
        return Modified(data={"iteration": 99})

    async def t2(ev: BeforeIteration) -> HookResult:
        observed.append(ev.iteration)
        return None

    for transform in (t1, t2):
        hub.register(BeforeIteration, transform, mode="transform")
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    assert observed == [0, 99]


# ---------------------------------------------------------------------------
# Observe
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_observe_handler_called_return_none_ignored():
    """An observer returning ``None`` is called and emit itself yields ``None``."""
    hub = HookCenter()
    sess = hub.create_session()
    spy = Mock(return_value=None)

    hub.register(BeforeIteration, spy, mode="observe")
    evt = BeforeIteration(iteration=0, messages=[])
    outcome = await hub.emit(evt, sess)

    spy.assert_called_once_with(evt)
    assert outcome is None


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_emit_with_no_handlers_is_noop():
    """Emitting with nothing registered returns ``None``."""
    hub = HookCenter()
    sess = hub.create_session()
    evt = BeforeIteration(iteration=0, messages=[])

    outcome = await hub.emit(evt, sess)

    assert outcome is None


@pytest.mark.asyncio
async def test_guard_deny_blocks_transforms_and_observes():
    """After a guard denies, neither transforms nor observers are invoked."""
    hub = HookCenter()
    sess = hub.create_session()
    gatekeeper = Mock(return_value=Deny("stop"))
    transformer = Mock()
    observer = Mock()

    hub.register(BeforeIteration, gatekeeper, mode="guard")
    hub.register(BeforeIteration, transformer, mode="transform")
    hub.register(BeforeIteration, observer, mode="observe")
    evt = BeforeIteration(iteration=0, messages=[])

    outcome = await hub.emit(evt, sess)

    assert isinstance(outcome, Deny)
    transformer.assert_not_called()
    observer.assert_not_called()


# ---------------------------------------------------------------------------
# Dedup: same handler + event type + mode
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_same_handler_registered_twice_dedup():
    """Registering the identical (event, handler, mode) twice yields one call."""
    hub = HookCenter()
    sess = hub.create_session()
    spy = Mock(return_value=None)

    for _ in range(2):
        hub.register(BeforeIteration, spy, mode="guard")
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    spy.assert_called_once_with(evt)


@pytest.mark.asyncio
async def test_same_handler_different_modes_not_deduped():
    """The same callable under two different modes is invoked once per mode."""
    hub = HookCenter()
    sess = hub.create_session()
    spy = Mock(return_value=None)

    for mode in ("guard", "observe"):
        hub.register(BeforeIteration, spy, mode=mode)
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    assert spy.call_count == 2


# ---------------------------------------------------------------------------
# Error handling
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_handler_error_reraise_false_caught():
    """An observer that raises is swallowed and later observers still run."""
    hub = HookCenter()
    sess = hub.create_session()

    async def bad_handler(event):
        raise RuntimeError("boom")

    survivor = Mock(return_value=None)

    hub.register(BeforeIteration, bad_handler, mode="observe")
    hub.register(BeforeIteration, survivor, mode="observe")
    evt = BeforeIteration(iteration=0, messages=[])

    outcome = await hub.emit(evt, sess)

    assert outcome is None
    survivor.assert_called_once_with(evt)


@pytest.mark.asyncio
async def test_handler_error_reraise_true_propagates():
    """With ``reraise=True`` the handler's exception escapes ``emit``."""
    hub = HookCenter()
    sess = hub.create_session()

    async def bad_handler(event):
        raise RuntimeError("propagate-me")

    hub.register_internal(
        sess, BeforeIteration, bad_handler, reraise=True, mode="observe"
    )
    evt = BeforeIteration(iteration=0, messages=[])

    with pytest.raises(RuntimeError, match="propagate-me"):
        await hub.emit(evt, sess)


@pytest.mark.asyncio
async def test_internal_handler_reraise_false_caught_others_continue():
    """An internal handler failing with ``reraise=False`` does not stop its peers."""
    hub = HookCenter()
    sess = hub.create_session()

    async def bad(event):
        raise RuntimeError("err")

    survivor = Mock(return_value=None)

    for callback in (bad, survivor):
        hub.register_internal(
            sess, BeforeIteration, callback, reraise=False, mode="observe"
        )
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    survivor.assert_called_once_with(evt)


# ---------------------------------------------------------------------------
# wants_streaming
# ---------------------------------------------------------------------------


def test_wants_streaming_false_with_no_handlers():
    """With nothing registered, streaming is not requested."""
    hub = HookCenter()
    sess = hub.create_session()

    assert hub.wants_streaming(sess) is False


def test_wants_streaming_true_when_internal_streaming_handler_registered():
    """An internal ``OnStream`` handler switches streaming on for its session."""
    hub = HookCenter()
    sess = hub.create_session()

    hub.register_internal(sess, OnStream, Mock(), mode="observe")

    assert hub.wants_streaming(sess) is True


def test_wants_streaming_true_when_external_streaming_handler_registered():
    """An external ``OnStream`` handler switches streaming on for a later session."""
    hub = HookCenter()

    hub.register(OnStream, Mock(), mode="observe")

    # The session is deliberately created after the external registration.
    sess = hub.create_session()
    assert hub.wants_streaming(sess) is True


def test_wants_streaming_true_for_on_stream_end():
    """``OnStreamEnd`` alone is enough to request streaming."""
    hub = HookCenter()

    hub.register(OnStreamEnd, Mock(), mode="observe")

    sess = hub.create_session()
    assert hub.wants_streaming(sess) is True


# ---------------------------------------------------------------------------
# finalize_content
# ---------------------------------------------------------------------------


def test_finalize_content_pipeline():
    """Transforms are chained in registration order over the content string."""
    hub = HookCenter()
    sess = hub.create_session()

    def upper(c: str | None) -> str | None:
        if c:
            return c.upper()
        return c

    def suffix(c: str | None) -> str | None:
        if c:
            return c + "!"
        return c

    for step in (upper, suffix):
        hub.register_internal(sess, FinalizeContent, step, mode="transform")

    assert hub.finalize_content("hello", sess) == "HELLO!"


def test_finalize_content_none_passthrough():
    """``None`` content with no handlers comes back as ``None``."""
    hub = HookCenter()
    sess = hub.create_session()

    assert hub.finalize_content(None, sess) is None


def test_finalize_content_external_handlers():
    """Internal and external transforms both contribute to the final content."""
    hub = HookCenter()
    sess = hub.create_session()

    def internal(c):
        if c:
            return c.upper()
        return c

    def external(c):
        if c:
            return c + "!"
        return c

    hub.register_internal(sess, FinalizeContent, internal, mode="transform")
    hub.register(FinalizeContent, external, mode="transform")

    finalized = hub.finalize_content("hey", sess)

    assert finalized == "HEY!"
def test_finalize_content_error_caught():
    """A raising transform (``reraise=False``) is skipped; the next one still runs."""
    center = HookCenter()
    session = center.create_session()

    def bad(c):
        raise RuntimeError("boom")

    def good(c):
        return (c or "") + "suffix"

    center.register_internal(session, FinalizeContent, bad, reraise=False, mode="transform")
    center.register_internal(session, FinalizeContent, good, mode="transform")

    result = center.finalize_content("hi", session)
    assert result == "hisuffix"


# ---------------------------------------------------------------------------
# reset
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_reset_clears_all_handlers():
    """After ``reset()`` a previously registered external handler no longer fires."""
    center = HookCenter()
    session = center.create_session()
    handler = Mock(return_value=None)

    center.register(BeforeIteration, handler, mode="observe")
    center.reset()

    event = BeforeIteration(iteration=0, messages=[])
    await center.emit(event, session)

    handler.assert_not_called()


# ---------------------------------------------------------------------------
# Internal vs external independence
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_external_does_not_overwrite_internal():
    """Internal and external handlers for the same event/mode are both invoked."""
    center = HookCenter()
    session = center.create_session()

    internal = Mock(return_value=None)
    external = Mock(return_value=None)

    center.register_internal(session, BeforeIteration, internal, mode="observe")
    center.register(BeforeIteration, external, mode="observe")

    event = BeforeIteration(iteration=0, messages=[])
    await center.emit(event, session)

    internal.assert_called_once_with(event)
    external.assert_called_once_with(event)


# ---------------------------------------------------------------------------
# Guard → Transform → Observe order
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_dispatch_order_guard_transform_observe():
    """Handlers run in guard, transform, observe order."""
    center = HookCenter()
    session = center.create_session()
    order: list[str] = []

    async def g(event):
        order.append("guard")
        return None

    async def t(event):
        order.append("transform")
        return None

    async def o(event):
        order.append("observe")
        return None

    center.register(BeforeIteration, g, mode="guard")
    center.register(BeforeIteration, t, mode="transform")
    center.register(BeforeIteration, o, mode="observe")
    event = BeforeIteration(iteration=0, messages=[])

    await center.emit(event, session)

    assert order == ["guard", "transform", "observe"]


# ---------------------------------------------------------------------------
# Internal before external order
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_internal_runs_before_external():
    """Within one mode, the internal handler is dispatched before the external one."""
    center = HookCenter()
    session = center.create_session()
    order: list[str] = []

    async def internal_obs(event):
        order.append("internal")
        return None

    async def external_obs(event):
        order.append("external")
        return None

    center.register_internal(session, BeforeIteration, internal_obs, mode="observe")
    center.register(BeforeIteration, external_obs, mode="observe")
    event = BeforeIteration(iteration=0, messages=[])

    await center.emit(event, session)

    assert order == ["internal", "external"]


# ---------------------------------------------------------------------------
# register_point
# ---------------------------------------------------------------------------
# NOTE(review): this section has no tests under it. Either add coverage for
# register_point or drop the banner so the file does not advertise it.


# ---------------------------------------------------------------------------
# Invalid mode
# ---------------------------------------------------------------------------


def test_register_invalid_mode_raises():
    """``register`` rejects an unrecognised mode with ``ValueError``."""
    center = HookCenter()
    with pytest.raises(ValueError, match="Unknown mode"):
        center.register(BeforeIteration, Mock(), mode="invalid")


def test_register_internal_invalid_mode_raises():
    """``register_internal`` rejects an unrecognised mode with ``ValueError``."""
    center = HookCenter()
    session = center.create_session()
    with pytest.raises(ValueError, match="Unknown mode"):
        center.register_internal(session, BeforeIteration, Mock(), mode="nope")


# ---------------------------------------------------------------------------
# discover placeholder
# ---------------------------------------------------------------------------


def test_discover_is_noop_placeholder():
    """``discover(None)`` must not raise and must leave no external handlers.

    The original test had no assertion, so it passed for any behaviour that
    did not raise. Pinning the empty registry makes the "no config means
    nothing is loaded" expectation explicit.
    """
    center = HookCenter()
    center.discover(None)

    assert center._external_handlers == {}


# ---------------------------------------------------------------------------
# Deny.abort
# ---------------------------------------------------------------------------


def test_deny_abort_defaults_false():
    """``Deny.abort`` defaults to ``False`` when not supplied."""
    # Plain synchronous test: nothing is awaited, so no asyncio marker is needed.
    d = Deny(reason="blocked")
    assert d.abort is False


@pytest.mark.asyncio
async def test_deny_abort_true_propagated_through_emit():
    """A guard's ``Deny(abort=True)`` is returned intact and stops later handlers."""
    center = HookCenter()
    session = center.create_session()
    guard = Mock(return_value=Deny(reason="stop", abort=True))
    later = Mock()

    center.register(BeforeIteration, guard, mode="guard")
    center.register(BeforeIteration, later, mode="observe")
    event = BeforeIteration(iteration=0, messages=[])

    result = await center.emit(event, session)

    assert isinstance(result, Deny)
    assert result.reason == "stop"
    assert result.abort is True
    later.assert_not_called()


@pytest.mark.asyncio
async def test_deny_abort_false_does_not_set_abort():
    """A ``Deny`` without ``abort`` comes back from emit with ``abort`` False."""
    center = HookCenter()
    session = center.create_session()
    guard = Mock(return_value=Deny(reason="soft"))

    center.register(BeforeIteration, guard, mode="guard")
    event = BeforeIteration(iteration=0, messages=[])

    result = await center.emit(event, session)

    assert isinstance(result, Deny)
    assert result.abort is False


# ---------------------------------------------------------------------------
# Guard exception handling
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_guard_exception_reraise_false_others_continue():
    """A raising external guard is swallowed and the next guard still runs."""
    hub = HookCenter()
    sess = hub.create_session()

    async def bad_guard(event):
        raise RuntimeError("guard boom")

    good_guard = Mock(return_value=None)

    for guard in (bad_guard, good_guard):
        hub.register(BeforeIteration, guard, mode="guard")
    evt = BeforeIteration(iteration=0, messages=[])

    outcome = await hub.emit(evt, sess)

    assert outcome is None
    good_guard.assert_called_once_with(evt)


@pytest.mark.asyncio
async def test_guard_exception_reraise_true_propagates():
    """An internal guard registered with ``reraise=True`` lets its error escape."""
    hub = HookCenter()
    sess = hub.create_session()

    async def bad_guard(event):
        raise RuntimeError("guard propagate")

    hub.register_internal(sess, BeforeIteration, bad_guard, reraise=True, mode="guard")
    evt = BeforeIteration(iteration=0, messages=[])

    with pytest.raises(RuntimeError, match="guard propagate"):
        await hub.emit(evt, sess)


# ---------------------------------------------------------------------------
# Transform exception handling
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_transform_exception_non_reraise_pipeline_continues():
    """A raising transform does not prevent the following transform from running."""
    hub = HookCenter()
    sess = hub.create_session()
    trail: list[str] = []

    async def bad_transform(event):
        trail.append("bad")
        raise RuntimeError("transform boom")

    async def good_transform(event):
        trail.append("good")
        return None

    for transform in (bad_transform, good_transform):
        hub.register(BeforeIteration, transform, mode="transform")
    evt = BeforeIteration(iteration=0, messages=[])

    await hub.emit(evt, sess)

    assert trail == ["bad", "good"]


# ---------------------------------------------------------------------------
# finalize_content: Deny stops pipeline
#
--------------------------------------------------------------------------- + + +def test_finalize_content_deny_stops_pipeline(): + center = HookCenter() + session = center.create_session() + + def deny_handler(c): + return Deny(reason="blocked") + + def after_deny(c): + return "should not run" + + center.register_internal(session, FinalizeContent, deny_handler, mode="transform") + center.register_internal(session, FinalizeContent, after_deny, mode="transform") + + result = center.finalize_content("hello", session) + assert result == "hello" + + +# --------------------------------------------------------------------------- +# Independent sessions +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emit_independent_sessions(): + center = HookCenter() + session_a = center.create_session() + session_b = center.create_session() + + events_a: list[int] = [] + events_b: list[int] = [] + + async def handler_a(event): + events_a.append(event.iteration) + + async def handler_b(event): + events_b.append(event.iteration) + + center.register_internal(session_a, BeforeIteration, handler_a, mode="observe") + center.register_internal(session_b, BeforeIteration, handler_b, mode="observe") + + event_a = BeforeIteration(iteration=1, messages=[]) + event_b = BeforeIteration(iteration=2, messages=[]) + + await center.emit(event_a, session_a) + await center.emit(event_b, session_b) + + assert events_a == [1] + assert events_b == [2] diff --git a/tests/hooks/test_discovery.py b/tests/hooks/test_discovery.py new file mode 100644 index 00000000000..8900bf89894 --- /dev/null +++ b/tests/hooks/test_discovery.py @@ -0,0 +1,380 @@ +"""Tests for HookCenter entry-point plugin discovery.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from nanobot.hooks.center import HookCenter +from nanobot.hooks.discovery import discover_hook_plugins, register_discovered 
from nanobot.hooks.event_types import (
    AfterIteration,
    BeforeExecuteTools,
    BeforeIteration,
)

# Patched for every test so no real installed entry points are consulted.
# NOTE(review): this only takes effect if the discovery module resolves
# ``importlib.metadata.entry_points`` at call time — confirm against discovery.py.
_EP_TARGET = "importlib.metadata.entry_points"


def _make_entry_point(name: str, handler):
    """Return a stand-in entry point whose ``load()`` yields *handler*."""
    return SimpleNamespace(name=name, load=lambda _h=handler: _h)


def _make_entry_point_with_error(name: str):
    """Return a stand-in entry point whose ``load()`` raises ``ImportError``."""

    def _boom():
        raise ImportError(f"cannot import {name}")

    return SimpleNamespace(name=name, load=_boom)


def _make_config(enabled_plugins):
    """Build a minimal config object exposing ``hooks.enabled_plugins``."""
    return SimpleNamespace(hooks=SimpleNamespace(enabled_plugins=enabled_plugins))


def _make_handler(*hook_events):
    """Return a Mock handler subscribed to the given ``(event_type, mode)`` pairs."""
    handler = Mock(return_value=None)
    handler.hook_events = list(hook_events)
    return handler


# ---------------------------------------------------------------------------
# discover_hook_plugins
# ---------------------------------------------------------------------------


def test_discover_returns_empty_dict_for_no_entry_points():
    """No entry points means an empty result."""
    with patch(_EP_TARGET, return_value=[]):
        result = discover_hook_plugins()

    assert result == {}


def test_discover_with_none_enabled_loads_nothing():
    """``enabled=None`` is deny-all: nothing is returned and nothing is loaded."""
    ep = SimpleNamespace(name="my_plugin", load=Mock(return_value=object()))

    with patch(_EP_TARGET, return_value=[ep]):
        result = discover_hook_plugins(enabled=None)

    assert result == {}
    # The allowlist is a security control: a non-enabled plugin's code must
    # never be imported, so load() must not have been invoked.
    ep.load.assert_not_called()


def test_discover_loads_single_plugin():
    """An allowlisted plugin is loaded and keyed by its entry-point name."""
    handler = object()

    with patch(_EP_TARGET, return_value=[_make_entry_point("my_plugin", handler)]):
        result = discover_hook_plugins(enabled=["my_plugin"])

    assert result == {"my_plugin": handler}


def test_discover_loads_multiple_plugins():
    """Every allowlisted plugin is loaded."""
    h1, h2 = object(), object()

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("a", h1), _make_entry_point("b", h2)],
    ):
        result = discover_hook_plugins(enabled=["a", "b"])

    assert result == {"a": h1, "b": h2}


def test_discover_skips_failed_plugin_loads_others():
    """A plugin whose ``load()`` raises is dropped; the others are still returned."""
    h2 = object()

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point_with_error("broken"),
            _make_entry_point("ok", h2),
        ],
    ):
        result = discover_hook_plugins(enabled=["broken", "ok"])

    assert "broken" not in result
    assert result == {"ok": h2}


# ---------------------------------------------------------------------------
# register_discovered — happy path
# ---------------------------------------------------------------------------


def test_register_discovered_registers_handler():
    """An enabled plugin is registered under the event type and mode it declares."""
    center = HookCenter()
    handler = _make_handler((BeforeIteration, "guard"))

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("testguard", handler)],
    ):
        register_discovered(center, _make_config(["testguard"]))

    external = center._external_handlers
    assert BeforeIteration in external
    assert external[BeforeIteration]["guard"] == [handler]


def test_register_discovered_registers_multiple_plugins():
    """Two enabled plugins with different modes land in their own mode lists."""
    center = HookCenter()
    h1 = _make_handler((BeforeIteration, "guard"))
    h2 = _make_handler((BeforeIteration, "observe"))

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point("g", h1),
            _make_entry_point("o", h2),
        ],
    ):
        register_discovered(center, _make_config(["g", "o"]))

    external = center._external_handlers
    assert external[BeforeIteration]["guard"] == [h1]
    assert external[BeforeIteration]["observe"] == [h2]


def test_register_discovered_plugin_subscribes_multiple_event_types():
    """One plugin may subscribe to several ``(event_type, mode)`` pairs."""
    center = HookCenter()
    handler = _make_handler(
        (BeforeIteration, "guard"),
        (AfterIteration, "observe"),
    )

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("multi", handler)],
    ):
        register_discovered(center, _make_config(["multi"]))

    external = center._external_handlers
    assert external[BeforeIteration]["guard"] == [handler]
    assert external[AfterIteration]["observe"] == [handler]


def test_register_discovered_multiple_plugins_same_event():
    """Plugins sharing an event and mode are kept in entry-point order."""
    center = HookCenter()
    h1 = _make_handler((BeforeIteration, "observe"))
    h2 = _make_handler((BeforeIteration, "observe"))

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point("p1", h1),
            _make_entry_point("p2", h2),
        ],
    ):
        register_discovered(center, _make_config(["p1", "p2"]))

    external = center._external_handlers
    assert external[BeforeIteration]["observe"] == [h1, h2]


# ---------------------------------------------------------------------------
# register_discovered — allowlist
# ---------------------------------------------------------------------------


def test_register_discovered_respects_enabled_plugins_allowlist():
    """Only allowlisted plugins are registered; others are never even loaded."""
    center = HookCenter()
    allowed = _make_handler((BeforeIteration, "observe"))
    blocked = _make_handler((BeforeIteration, "observe"))
    # A Mock load lets the test prove the blocked plugin was not imported.
    blocked_ep = SimpleNamespace(name="blocked_plugin", load=Mock(return_value=blocked))

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point("allowed_plugin", allowed),
            blocked_ep,
        ],
    ):
        register_discovered(center, _make_config(["allowed_plugin"]))

    external = center._external_handlers
    assert external[BeforeIteration]["observe"] == [allowed]
    blocked_ep.load.assert_not_called()


def test_register_discovered_allowlist_none_denies_all():
    """``enabled_plugins=None`` registers nothing."""
    center = HookCenter()
    h1 = _make_handler((BeforeIteration, "observe"))
    h2 = _make_handler((BeforeIteration, "observe"))

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point("p1", h1),
            _make_entry_point("p2", h2),
        ],
    ):
        register_discovered(center, _make_config(None))

    assert center._external_handlers == {}


def test_register_discovered_no_hooks_config_denies_all():
    """A config object without a ``hooks`` attribute registers nothing."""
    center = HookCenter()
    handler = _make_handler((BeforeIteration, "observe"))

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("p", handler)],
    ):
        register_discovered(center, SimpleNamespace())

    assert center._external_handlers == {}


def test_register_discovered_no_config_denies_all():
    """Omitting the config argument entirely registers nothing."""
    center = HookCenter()
    handler = _make_handler((BeforeIteration, "observe"))

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("p", handler)],
    ):
        register_discovered(center)

    assert center._external_handlers == {}


# ---------------------------------------------------------------------------
# register_discovered — edge cases
# ---------------------------------------------------------------------------


def test_register_discovered_empty_entry_points_noop():
    """With no entry points available there is nothing to register."""
    center = HookCenter()

    with patch(_EP_TARGET, return_value=[]):
        register_discovered(center)

    assert center._external_handlers == {}


def test_register_discovered_skips_plugin_without_hook_events():
    """An enabled plugin that declares no ``hook_events`` is skipped.

    The plugin must be allowlisted. Otherwise deny-all filters it out first
    and the missing-``hook_events`` path is never reached, which is what the
    original test did by calling ``register_discovered(center)`` with no config.
    """
    center = HookCenter()
    # A bare object() is used on purpose: a Mock would auto-create a
    # ``hook_events`` attribute on access.
    handler = object()

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("noevents", handler)],
    ):
        register_discovered(center, _make_config(["noevents"]))

    assert center._external_handlers == {}


# ---------------------------------------------------------------------------
# register_discovered — error paths
# ---------------------------------------------------------------------------


def test_register_discovered_single_plugin_load_error_skips_and_continues():
    """A plugin that fails to load does not prevent a later plugin registering."""
    center = HookCenter()
    ok_handler = _make_handler((BeforeExecuteTools, "transform"))

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point_with_error("bad"),
            _make_entry_point("good", ok_handler),
        ],
    ):
        register_discovered(center, _make_config(["bad", "good"]))

    external = center._external_handlers
    assert BeforeExecuteTools in external
    assert external[BeforeExecuteTools]["transform"] == [ok_handler]


def test_register_discovered_entry_points_raises_no_handlers_registered():
    """A failure inside ``entry_points()`` is contained and registers nothing."""
    center = HookCenter()

    # An exception instance as side_effect raises regardless of how
    # entry_points is called (positional, keyword, or no arguments). A
    # non-empty allowlist ensures discovery cannot short-circuit on deny-all
    # before reaching the failing call.
    with patch(_EP_TARGET, side_effect=RuntimeError("metadata not available")):
        register_discovered(center, _make_config(["p"]))

    assert center._external_handlers == {}


# ---------------------------------------------------------------------------
# HookCenter.discover integration
# ---------------------------------------------------------------------------


def test_center_discover_delegates_to_register_discovered():
    """``HookCenter.discover(config)`` registers enabled plugins on the center."""
    center = HookCenter()
    handler = _make_handler((BeforeIteration, "guard"))

    with patch(
        _EP_TARGET,
        return_value=[_make_entry_point("p", handler)],
    ):
        center.discover(_make_config(["p"]))

    external = center._external_handlers
    assert external[BeforeIteration]["guard"] == [handler]


def test_center_discover_with_config():
    """``HookCenter.discover`` honours the ``enabled_plugins`` allowlist."""
    center = HookCenter()
    allowed = _make_handler((BeforeIteration, "observe"))
    blocked = _make_handler((BeforeIteration, "observe"))

    with patch(
        _EP_TARGET,
        return_value=[
            _make_entry_point("allowed", allowed),
            _make_entry_point("blocked", blocked),
        ],
    ):
        center.discover(_make_config(["allowed"]))

    external = center._external_handlers
    assert external[BeforeIteration]["observe"] == [allowed]


def test_center_discover_does_not_prevent_agent_startup_on_error():
    """``discover`` never raises when entry-point metadata is unavailable."""
    center = HookCenter()

    with patch(_EP_TARGET, side_effect=RuntimeError("metadata unavailable")):
        # The no-argument form is kept from the original test.
        center.discover()
        # An allowlisted call ensures discovery actually reaches entry_points().
        center.discover(_make_config(["p"]))

    assert center._external_handlers == {}