Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub mod shell;
pub mod tool_search;
pub mod traits;
pub mod web_fetch;
mod web_search_provider_routing;
pub mod web_search_tool;

pub use browser::{BrowserTool, ComputerUseConfig};
Expand Down
73 changes: 73 additions & 0 deletions src/tools/web_search_provider_routing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebSearchProviderRoute {
DuckDuckGo,
Brave,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WebSearchProviderResolution {
pub route: WebSearchProviderRoute,
pub canonical_provider: &'static str,
pub used_fallback: bool,
}

pub const DEFAULT_WEB_SEARCH_PROVIDER: &str = "duckduckgo";
const BRAVE_PROVIDER: &str = "brave";

pub fn resolve_web_search_provider(raw_provider: &str) -> WebSearchProviderResolution {
let normalized = raw_provider.trim().to_ascii_lowercase();
match normalized.as_str() {
"" | "default" | "duckduckgo" | "ddg" | "duck-duck-go" | "duck_duck_go" => {
WebSearchProviderResolution {
route: WebSearchProviderRoute::DuckDuckGo,
canonical_provider: DEFAULT_WEB_SEARCH_PROVIDER,
used_fallback: false,
}
}
"brave" | "brave-search" | "brave_search" => WebSearchProviderResolution {
route: WebSearchProviderRoute::Brave,
canonical_provider: BRAVE_PROVIDER,
used_fallback: false,
},
_ => WebSearchProviderResolution {
route: WebSearchProviderRoute::DuckDuckGo,
canonical_provider: DEFAULT_WEB_SEARCH_PROVIDER,
used_fallback: true,
},
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn resolve_aliases_to_duckduckgo() {
let ddg_aliases = ["duckduckgo", "ddg", "duck-duck-go", "duck_duck_go"];
for alias in ddg_aliases {
let resolved = resolve_web_search_provider(alias);
assert_eq!(resolved.route, WebSearchProviderRoute::DuckDuckGo);
assert_eq!(resolved.canonical_provider, DEFAULT_WEB_SEARCH_PROVIDER);
assert!(!resolved.used_fallback);
}
}

#[test]
fn resolve_aliases_to_brave() {
let brave_aliases = ["brave", "brave-search", "brave_search"];
for alias in brave_aliases {
let resolved = resolve_web_search_provider(alias);
assert_eq!(resolved.route, WebSearchProviderRoute::Brave);
assert_eq!(resolved.canonical_provider, BRAVE_PROVIDER);
assert!(!resolved.used_fallback);
}
}

#[test]
fn resolve_unknown_provider_falls_back_to_default() {
let resolved = resolve_web_search_provider("bing");
assert_eq!(resolved.route, WebSearchProviderRoute::DuckDuckGo);
assert_eq!(resolved.canonical_provider, DEFAULT_WEB_SEARCH_PROVIDER);
assert!(resolved.used_fallback);
}
}
21 changes: 14 additions & 7 deletions src/tools/web_search_tool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::traits::{Tool, ToolResult};
use super::web_search_provider_routing::{resolve_web_search_provider, WebSearchProviderRoute};
use async_trait::async_trait;
use regex::Regex;
use serde_json::json;
Expand All @@ -13,6 +14,7 @@ use std::time::Duration;
/// `[web_search] brave_api_key` field, and uses the result. This ensures that
/// keys set or rotated after boot, and encrypted keys, are correctly picked up.
pub struct WebSearchTool {
/// Provider selector as configured by user. Routed via provider aliases at runtime.
provider: String,
/// Boot-time key snapshot (may be `None` if not yet configured at startup).
boot_brave_api_key: Option<String>,
Expand Down Expand Up @@ -300,13 +302,18 @@ impl Tool for WebSearchTool {

tracing::info!("Searching web for: {}", query);

let result = match self.provider.as_str() {
"duckduckgo" | "ddg" => self.search_duckduckgo(query).await?,
"brave" => self.search_brave(query).await?,
_ => anyhow::bail!(
"Unknown search provider: '{}'. Set tools.web_search.provider to 'duckduckgo' or 'brave' in config.toml",
self.provider
),
let resolution = resolve_web_search_provider(&self.provider);
if resolution.used_fallback {
tracing::warn!(
"Unknown web search provider '{}'; falling back to '{}'",
self.provider,
resolution.canonical_provider
);
}

let result = match resolution.route {
WebSearchProviderRoute::DuckDuckGo => self.search_duckduckgo(query).await?,
WebSearchProviderRoute::Brave => self.search_brave(query).await?,
};

Ok(ToolResult {
Expand Down
Loading