Skip to content

Commit 99e8f22

Browse files
ilblackdragonclaude
authored andcommitted
fix: persist turns after approval and add agent-level tests (nearai#250)
* fix: persist turns after approval and add agent-level tests Port relevant changes from PR nearai#112 that were not carried over to nearai#237: - Add persist_turn calls in process_approval for the response, error, and auth-required paths. Previously, turns completed after tool approval were never persisted to DB — if the process crashed after approval the entire turn (user message + assistant response) was lost. - Add agent-level unit tests: StaticLlmProvider mock, make_test_agent helper, tests for auto-approval logic, destructive shell command detection, and PendingApproval backward-compatible deserialization (without deferred_tool_calls field). - Remove unused _thread_state binding in process_approval. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address 14 audit findings in src/agent/ Audit of the agent module found 2 High, 7 Medium, 3 Low, and 2 Nit severity issues. This commit fixes all of them: High: - Remove 4 `.expect()` calls in session.rs (entry API, match, direct indexing, if-let) to eliminate panic paths in production - Add typed RoutineError enum replacing Result<_, String> across routine.rs, routine_engine.rs, and callers in history/store.rs and db/libsql/mod.rs Medium: - Sanitize routine names in path construction to prevent directory traversal (routine_engine.rs) - Log warnings for 5 silently-swallowed errors in scheduler.rs, compaction.rs, and worker.rs - Extract shared handle_auth_intercept helper to deduplicate auth interception in thread_ops.rs - Add session count warning threshold in session_manager.rs - Make FullJob stub degradation visible via warn-level log and prepended warning in output Low: - Restrict dead code visibility with #[cfg(test)] on 19 unused items in submission.rs, task.rs, and undo.rs - Narrow pub to pub(crate) on self_repair.rs builder methods - Remove TaskStatus from mod.rs re-exports (test-only type) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address PR review comments - Reorder persist_turn before persist_response_chain so the conversation row exists before the metadata UPDATE runs - Add persist_response_chain call to handle_auth_intercept so auth-required paths preserve the response chain - Harden sanitize_routine_name to use allowlist (alphanumeric, dash, underscore) instead of denylist replacements - Fix stale active_thread ID in get_or_create_thread: fall back to create_thread() when the stored ID is missing from the map - Persist turn on approval rejection so user messages survive crashes after a tool is rejected Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 71253c5 commit 99e8f22

17 files changed

Lines changed: 582 additions & 155 deletions

src/agent/compaction.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,16 @@ impl ContextCompactor {
105105

106106
// Write to workspace if available
107107
let summary_written = if let Some(ws) = workspace {
108-
self.write_summary_to_workspace(ws, &summary).await.is_ok()
108+
match self.write_summary_to_workspace(ws, &summary).await {
109+
Ok(()) => true,
110+
Err(e) => {
111+
tracing::warn!(
112+
"Compaction summary write failed (turns will still be truncated): {}",
113+
e
114+
);
115+
false
116+
}
117+
}
109118
} else {
110119
false
111120
};
@@ -157,7 +166,16 @@ impl ContextCompactor {
157166
let content = format_turns_for_storage(old_turns);
158167

159168
// Write to workspace
160-
let written = self.write_context_to_workspace(ws, &content).await.is_ok();
169+
let written = match self.write_context_to_workspace(ws, &content).await {
170+
Ok(()) => true,
171+
Err(e) => {
172+
tracing::warn!(
173+
"Compaction context write failed (turns will still be truncated): {}",
174+
e
175+
);
176+
false
177+
}
178+
};
161179

162180
// Truncate
163181
thread.truncate_turns(keep_recent);

src/agent/dispatcher.rs

Lines changed: 214 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ impl Agent {
402402
// and short-circuit: return the instructions directly so
403403
// the LLM doesn't get a chance to hallucinate tool calls.
404404
if let Some((ext_name, instructions)) =
405-
detect_auth_awaiting(&tc.name, &tool_result)
405+
check_auth_required(&tc.name, &tool_result)
406406
{
407407
let auth_data = parse_auth_result(&tool_result);
408408
{
@@ -581,7 +581,7 @@ pub(super) fn parse_auth_result(result: &Result<String, Error>) -> ParsedAuthDat
581581
///
582582
/// Returns `Some((extension_name, instructions))` if the tool result contains
583583
/// `awaiting_token: true`, meaning the thread should enter auth mode.
584-
pub(super) fn detect_auth_awaiting(
584+
pub(super) fn check_auth_required(
585585
tool_name: &str,
586586
result: &Result<String, Error>,
587587
) -> Option<(String, String)> {
@@ -604,9 +604,213 @@ pub(super) fn detect_auth_awaiting(
604604

605605
#[cfg(test)]
606606
mod tests {
607+
use std::sync::Arc;
608+
use std::time::Duration;
609+
610+
use async_trait::async_trait;
611+
use rust_decimal::Decimal;
612+
613+
use crate::agent::agent_loop::{Agent, AgentDeps};
614+
use crate::agent::cost_guard::{CostGuard, CostGuardConfig};
615+
use crate::agent::session::Session;
616+
use crate::channels::ChannelManager;
617+
use crate::config::{AgentConfig, SafetyConfig, SkillsConfig};
618+
use crate::context::ContextManager;
607619
use crate::error::Error;
620+
use crate::hooks::HookRegistry;
621+
use crate::llm::{
622+
CompletionRequest, CompletionResponse, FinishReason, LlmProvider, ToolCall,
623+
ToolCompletionRequest, ToolCompletionResponse,
624+
};
625+
use crate::safety::SafetyLayer;
626+
use crate::tools::ToolRegistry;
627+
628+
use super::check_auth_required;
629+
630+
/// Minimal LLM provider for unit tests that always returns a static response.
631+
struct StaticLlmProvider;
632+
633+
#[async_trait]
634+
impl LlmProvider for StaticLlmProvider {
635+
fn model_name(&self) -> &str {
636+
"static-mock"
637+
}
638+
639+
fn cost_per_token(&self) -> (Decimal, Decimal) {
640+
(Decimal::ZERO, Decimal::ZERO)
641+
}
642+
643+
async fn complete(
644+
&self,
645+
_request: CompletionRequest,
646+
) -> Result<CompletionResponse, crate::error::LlmError> {
647+
Ok(CompletionResponse {
648+
content: "ok".to_string(),
649+
input_tokens: 0,
650+
output_tokens: 0,
651+
finish_reason: FinishReason::Stop,
652+
response_id: None,
653+
})
654+
}
655+
656+
async fn complete_with_tools(
657+
&self,
658+
_request: ToolCompletionRequest,
659+
) -> Result<ToolCompletionResponse, crate::error::LlmError> {
660+
Ok(ToolCompletionResponse {
661+
content: Some("ok".to_string()),
662+
tool_calls: Vec::new(),
663+
input_tokens: 0,
664+
output_tokens: 0,
665+
finish_reason: FinishReason::Stop,
666+
response_id: None,
667+
})
668+
}
669+
}
670+
671+
/// Build a minimal `Agent` for unit testing (no DB, no workspace, no extensions).
672+
fn make_test_agent() -> Agent {
673+
let deps = AgentDeps {
674+
store: None,
675+
llm: Arc::new(StaticLlmProvider),
676+
cheap_llm: None,
677+
safety: Arc::new(SafetyLayer::new(&SafetyConfig {
678+
max_output_length: 100_000,
679+
injection_check_enabled: true,
680+
})),
681+
tools: Arc::new(ToolRegistry::new()),
682+
workspace: None,
683+
extension_manager: None,
684+
skill_registry: None,
685+
skills_config: SkillsConfig::default(),
686+
hooks: Arc::new(HookRegistry::new()),
687+
cost_guard: Arc::new(CostGuard::new(CostGuardConfig::default())),
688+
};
689+
690+
Agent::new(
691+
AgentConfig {
692+
name: "test-agent".to_string(),
693+
max_parallel_jobs: 1,
694+
job_timeout: Duration::from_secs(60),
695+
stuck_threshold: Duration::from_secs(60),
696+
repair_check_interval: Duration::from_secs(30),
697+
max_repair_attempts: 1,
698+
use_planning: false,
699+
session_idle_timeout: Duration::from_secs(300),
700+
allow_local_tools: false,
701+
max_cost_per_day_cents: None,
702+
max_actions_per_hour: None,
703+
},
704+
deps,
705+
ChannelManager::new(),
706+
None,
707+
None,
708+
None,
709+
Some(Arc::new(ContextManager::new(1))),
710+
None,
711+
)
712+
}
713+
714+
#[test]
715+
fn test_make_test_agent_succeeds() {
716+
// Verify that a test agent can be constructed without panicking.
717+
let _agent = make_test_agent();
718+
}
719+
720+
#[test]
721+
fn test_auto_approved_tool_is_respected() {
722+
let _agent = make_test_agent();
723+
let mut session = Session::new("user-1");
724+
session.auto_approve_tool("http");
725+
726+
// A non-shell tool that is auto-approved should be approved.
727+
assert!(session.is_tool_auto_approved("http"));
728+
// A tool that hasn't been auto-approved should not be.
729+
assert!(!session.is_tool_auto_approved("shell"));
730+
}
731+
732+
#[test]
733+
fn test_shell_destructive_command_requires_approval_for() {
734+
// ShellTool::requires_approval_for should detect destructive commands.
735+
// This exercises the same code path used inline in run_agentic_loop.
736+
use crate::tools::builtin::shell::requires_explicit_approval;
737+
738+
let destructive_cmds = [
739+
"rm -rf /tmp/test",
740+
"git push --force origin main",
741+
"git reset --hard HEAD~5",
742+
];
743+
for cmd in &destructive_cmds {
744+
assert!(
745+
requires_explicit_approval(cmd),
746+
"'{}' should require explicit approval",
747+
cmd
748+
);
749+
}
750+
751+
let safe_cmds = ["git status", "cargo build", "ls -la"];
752+
for cmd in &safe_cmds {
753+
assert!(
754+
!requires_explicit_approval(cmd),
755+
"'{}' should not require explicit approval",
756+
cmd
757+
);
758+
}
759+
}
608760

609-
use super::detect_auth_awaiting;
761+
#[test]
762+
fn test_pending_approval_serialization_backcompat_without_deferred_calls() {
763+
// PendingApproval from before the deferred_tool_calls field was added
764+
// should deserialize with an empty vec (via #[serde(default)]).
765+
let json = serde_json::json!({
766+
"request_id": uuid::Uuid::new_v4(),
767+
"tool_name": "http",
768+
"parameters": {"url": "https://example.com", "method": "GET"},
769+
"description": "Make HTTP request",
770+
"tool_call_id": "call_123",
771+
"context_messages": [{"role": "user", "content": "go"}]
772+
})
773+
.to_string();
774+
775+
let parsed: crate::agent::session::PendingApproval =
776+
serde_json::from_str(&json).expect("should deserialize without deferred_tool_calls");
777+
778+
assert!(parsed.deferred_tool_calls.is_empty());
779+
assert_eq!(parsed.tool_name, "http");
780+
assert_eq!(parsed.tool_call_id, "call_123");
781+
}
782+
783+
#[test]
784+
fn test_pending_approval_serialization_roundtrip_with_deferred_calls() {
785+
let pending = crate::agent::session::PendingApproval {
786+
request_id: uuid::Uuid::new_v4(),
787+
tool_name: "shell".to_string(),
788+
parameters: serde_json::json!({"command": "echo hi"}),
789+
description: "Run shell command".to_string(),
790+
tool_call_id: "call_1".to_string(),
791+
context_messages: vec![],
792+
deferred_tool_calls: vec![
793+
ToolCall {
794+
id: "call_2".to_string(),
795+
name: "http".to_string(),
796+
arguments: serde_json::json!({"url": "https://example.com"}),
797+
},
798+
ToolCall {
799+
id: "call_3".to_string(),
800+
name: "echo".to_string(),
801+
arguments: serde_json::json!({"message": "done"}),
802+
},
803+
],
804+
};
805+
806+
let json = serde_json::to_string(&pending).expect("serialize");
807+
let parsed: crate::agent::session::PendingApproval =
808+
serde_json::from_str(&json).expect("deserialize");
809+
810+
assert_eq!(parsed.deferred_tool_calls.len(), 2);
811+
assert_eq!(parsed.deferred_tool_calls[0].name, "http");
812+
assert_eq!(parsed.deferred_tool_calls[1].name, "echo");
813+
}
610814

611815
#[test]
612816
fn test_detect_auth_awaiting_positive() {
@@ -619,7 +823,7 @@ mod tests {
619823
})
620824
.to_string());
621825

622-
let detected = detect_auth_awaiting("tool_auth", &result);
826+
let detected = check_auth_required("tool_auth", &result);
623827
assert!(detected.is_some());
624828
let (name, instructions) = detected.unwrap();
625829
assert_eq!(name, "telegram");
@@ -636,7 +840,7 @@ mod tests {
636840
})
637841
.to_string());
638842

639-
assert!(detect_auth_awaiting("tool_auth", &result).is_none());
843+
assert!(check_auth_required("tool_auth", &result).is_none());
640844
}
641845

642846
#[test]
@@ -647,14 +851,14 @@ mod tests {
647851
})
648852
.to_string());
649853

650-
assert!(detect_auth_awaiting("tool_list", &result).is_none());
854+
assert!(check_auth_required("tool_list", &result).is_none());
651855
}
652856

653857
#[test]
654858
fn test_detect_auth_awaiting_error_result() {
655859
let result: Result<String, Error> =
656860
Err(crate::error::ToolError::NotFound { name: "x".into() }.into());
657-
assert!(detect_auth_awaiting("tool_auth", &result).is_none());
861+
assert!(check_auth_required("tool_auth", &result).is_none());
658862
}
659863

660864
#[test]
@@ -666,7 +870,7 @@ mod tests {
666870
})
667871
.to_string());
668872

669-
let (_, instructions) = detect_auth_awaiting("tool_auth", &result).unwrap();
873+
let (_, instructions) = check_auth_required("tool_auth", &result).unwrap();
670874
assert_eq!(instructions, "Please provide your API token/key.");
671875
}
672876

@@ -681,7 +885,7 @@ mod tests {
681885
})
682886
.to_string());
683887

684-
let detected = detect_auth_awaiting("tool_activate", &result);
888+
let detected = check_auth_required("tool_activate", &result);
685889
assert!(detected.is_some());
686890
let (name, instructions) = detected.unwrap();
687891
assert_eq!(name, "slack");
@@ -697,6 +901,6 @@ mod tests {
697901
})
698902
.to_string());
699903

700-
assert!(detect_auth_awaiting("tool_activate", &result).is_none());
904+
assert!(check_auth_required("tool_activate", &result).is_none());
701905
}
702906
}

src/agent/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,6 @@ pub use self_repair::{BrokenTool, RepairResult, RepairTask, SelfRepair, StuckJob
4444
pub use session::{PendingApproval, PendingAuth, Session, Thread, ThreadState, Turn, TurnState};
4545
pub use session_manager::SessionManager;
4646
pub use submission::{Submission, SubmissionParser, SubmissionResult};
47-
pub use task::{Task, TaskContext, TaskHandler, TaskOutput, TaskStatus};
47+
pub use task::{Task, TaskContext, TaskHandler, TaskOutput};
4848
pub use undo::{Checkpoint, UndoManager};
4949
pub use worker::{Worker, WorkerDeps};

0 commit comments

Comments
 (0)