Skip to content

Commit 37884eb

Browse files
adds additional search params
Introduces support for additional Brave Search API web-search parameters.
1 parent 63a508f commit 37884eb

File tree

2 files changed

+151
-37
lines changed

2 files changed

+151
-37
lines changed

lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py

Lines changed: 131 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from datetime import datetime
2+
import json
23
import os
34
import time
4-
from typing import Any, ClassVar
5+
from typing import Annotated, Any, ClassVar, Literal
56

67
from crewai.tools import BaseTool, EnvVar
78
from pydantic import BaseModel, Field
9+
from pydantic.types import StringConstraints
810
import requests
911

1012

@@ -15,14 +17,60 @@ def _save_results_to_file(content: str) -> None:
1517
file.write(content)
1618

1719

20+
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
21+
FreshnessRange = Annotated[
22+
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
23+
]
24+
Freshness = FreshnessPreset | FreshnessRange
25+
SafeSearch = Literal["off", "moderate", "strict"]
26+
27+
1828
class BraveSearchToolSchema(BaseModel):
19-
"""Input for BraveSearchTool."""
29+
"""Input for BraveSearchTool"""
2030

21-
search_query: str = Field(
22-
..., description="Mandatory search query you want to use to search the internet"
31+
query: str = Field(..., description="Search query to perform")
32+
country: str | None = Field(
33+
default=None,
34+
description="Country code for geo-targeting (e.g., 'US', 'BR').",
35+
)
36+
search_language: str | None = Field(
37+
default=None,
38+
description="Language code for the search results (e.g., 'en', 'es').",
39+
)
40+
count: int | None = Field(
41+
default=None,
42+
description="The maximum number of results to return. Actual number may be less.",
43+
)
44+
offset: int | None = Field(
45+
default=None, description="Skip the first N result sets/pages. Max is 9."
46+
)
47+
safesearch: SafeSearch | None = Field(
48+
default=None,
49+
description="Filter out explicit content. Options: off/moderate/strict",
50+
)
51+
spellcheck: bool | None = Field(
52+
default=None,
53+
description="Attempt to correct spelling errors in the search query.",
54+
)
55+
freshness: Freshness | None = Field(
56+
default=None,
57+
description="Enforce freshness of results. Options: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
58+
)
59+
text_decorations: bool | None = Field(
60+
default=None,
61+
description="Include markup to highlight search terms in the results.",
62+
)
63+
extra_snippets: bool | None = Field(
64+
default=None,
65+
description="Include up to 5 text snippets for each page if possible.",
66+
)
67+
operators: bool | None = Field(
68+
default=None,
69+
description="Whether to apply search operators (e.g., site:example.com).",
2370
)
2471

2572

73+
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
2674
class BraveSearchTool(BaseTool):
2775
"""BraveSearchTool - A tool for performing web searches using the Brave Search API.
2876
@@ -35,13 +83,13 @@ class BraveSearchTool(BaseTool):
3583
- python-dotenv (for API key management)
3684
"""
3785

38-
name: str = "Brave Web Search the internet"
86+
name: str = "Brave Search"
3987
description: str = (
40-
"A tool that can be used to search the internet with a search_query."
88+
"A tool that performs web searches using the Brave Search API. "
89+
"Results are returned as structured JSON data."
4190
)
4291
args_schema: type[BaseModel] = BraveSearchToolSchema
4392
search_url: str = "https://api.search.brave.com/res/v1/web/search"
44-
country: str | None = ""
4593
n_results: int = 10
4694
save_file: bool = False
4795
_last_request_time: ClassVar[float] = 0
@@ -73,19 +121,64 @@ def _run(
73121
self._min_request_interval - (current_time - self._last_request_time)
74122
)
75123
BraveSearchTool._last_request_time = time.time()
124+
125+
# Construct and send the request
76126
try:
77-
search_query = kwargs.get("search_query") or kwargs.get("query")
78-
if not search_query:
79-
raise ValueError("Search query is required")
127+
# Maintain both "search_query" and "query" for backwards compatibility
128+
query = kwargs.get("search_query") or kwargs.get("query")
129+
if not query:
130+
raise ValueError("Query is required")
131+
132+
payload = {"q": query}
133+
134+
if country := kwargs.get("country"):
135+
payload["country"] = country
136+
137+
if search_language := kwargs.get("search_language"):
138+
payload["search_language"] = search_language
139+
140+
# Fallback to deprecated n_results parameter if no count is provided
141+
count = kwargs.get("count")
142+
if count is not None:
143+
payload["count"] = count
144+
else:
145+
payload["count"] = self.n_results
146+
147+
# Offset may be 0, so avoid truthiness check
148+
offset = kwargs.get("offset")
149+
if offset is not None:
150+
payload["offset"] = offset
151+
152+
if safesearch := kwargs.get("safesearch"):
153+
payload["safesearch"] = safesearch
80154

81155
save_file = kwargs.get("save_file", self.save_file)
82-
n_results = kwargs.get("n_results", self.n_results)
156+
if freshness := kwargs.get("freshness"):
157+
payload["freshness"] = freshness
158+
159+
# Boolean parameters
160+
spellcheck = kwargs.get("spellcheck")
161+
if spellcheck is not None:
162+
payload["spellcheck"] = spellcheck
163+
164+
text_decorations = kwargs.get("text_decorations")
165+
if text_decorations is not None:
166+
payload["text_decorations"] = text_decorations
83167

84-
payload = {"q": search_query, "count": n_results}
168+
extra_snippets = kwargs.get("extra_snippets")
169+
if extra_snippets is not None:
170+
payload["extra_snippets"] = extra_snippets
85171

86-
if self.country != "":
87-
payload["country"] = self.country
172+
operators = kwargs.get("operators")
173+
if operators is not None:
174+
payload["operators"] = operators
88175

176+
# Limit the result types to "web" since there is presently no
177+
# handling of other types like "discussions", "faq", "infobox",
178+
# "news", "videos", or "locations".
179+
payload["result_filter"] = "web"
180+
181+
# Setup Request Headers
89182
headers = {
90183
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
91184
"Accept": "application/json",
@@ -97,25 +190,32 @@ def _run(
97190
response.raise_for_status() # Handle non-200 responses
98191
results = response.json()
99192

193+
# TODO: Handle other result types like "discussions", "faq", etc.
194+
web_results_items = []
100195
if "web" in results:
101-
results = results["web"]["results"]
102-
string = []
103-
for result in results:
104-
try:
105-
string.append(
106-
"\n".join(
107-
[
108-
f"Title: {result['title']}",
109-
f"Link: {result['url']}",
110-
f"Snippet: {result['description']}",
111-
"---",
112-
]
113-
)
114-
)
115-
except KeyError: # noqa: PERF203
196+
web_results = results["web"]["results"]
197+
198+
for result in web_results:
199+
url = result.get("url")
200+
title = result.get("title")
201+
# If, for whatever reason, this entry does not have a title
202+
# or url, skip it.
203+
if not url or not title:
116204
continue
117-
118-
content = "\n".join(string)
205+
item = {
206+
"url": url,
207+
"title": title,
208+
}
209+
description = result.get("description")
210+
if description:
211+
item["description"] = description
212+
snippets = result.get("extra_snippets")
213+
if snippets:
214+
item["snippets"] = snippets
215+
216+
web_results_items.append(item)
217+
218+
content = json.dumps(web_results_items)
119219
except requests.RequestException as e:
120220
return f"Error performing search: {e!s}"
121221
except KeyError as e:

lib/crewai-tools/tests/tools/brave_search_tool_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,30 @@ def test_brave_tool_search(mock_get, brave_tool):
3030
}
3131
mock_get.return_value.json.return_value = mock_response
3232

33-
result = brave_tool.run(search_query="test")
33+
result = brave_tool.run(query="test")
3434
assert "Test Title" in result
3535
assert "http://test.com" in result
3636

3737

38-
def test_brave_tool():
39-
tool = BraveSearchTool(
40-
n_results=2,
41-
)
42-
tool.run(search_query="ChatGPT")
38+
@patch("requests.get")
39+
def test_brave_tool(mock_get):
40+
mock_response = {
41+
"web": {
42+
"results": [
43+
{
44+
"title": "Brave Browser",
45+
"url": "https://brave.com",
46+
"description": "Brave Browser description",
47+
}
48+
]
49+
}
50+
}
51+
mock_get.return_value.json.return_value = mock_response
52+
53+
tool = BraveSearchTool(n_results=2)
54+
result = tool.run(query="Brave Browser")
55+
assert "Brave Browser" in result
56+
assert "brave.com" in result
4357

4458

4559
if __name__ == "__main__":

0 commit comments

Comments
 (0)