Skip to content

Commit a6eb14a

Browse files
adds additional search params
Introduces support for additional Brave Search API web-search parameters.
1 parent 63a508f commit a6eb14a

File tree

2 files changed

+161
-38
lines changed

2 files changed

+161
-38
lines changed

lib/crewai-tools/src/crewai_tools/tools/brave_search_tool/brave_search_tool.py

Lines changed: 141 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from datetime import datetime
2+
import json
23
import os
34
import time
4-
from typing import Any, ClassVar
5+
from typing import Annotated, Any, ClassVar, Literal
56

67
from crewai.tools import BaseTool, EnvVar
78
from pydantic import BaseModel, Field
9+
from pydantic.types import StringConstraints
810
import requests
911

1012

@@ -15,14 +17,75 @@ def _save_results_to_file(content: str) -> None:
1517
file.write(content)
1618

1719

20+
FreshnessPreset = Literal["pd", "pw", "pm", "py"]
21+
FreshnessRange = Annotated[
22+
str, StringConstraints(pattern=r"^\d{4}-\d{2}-\d{2}to\d{4}-\d{2}-\d{2}$")
23+
]
24+
Freshness = FreshnessPreset | FreshnessRange
25+
SafeSearch = Literal["off", "moderate", "strict"]
26+
27+
1828
class BraveSearchToolSchema(BaseModel):
19-
"""Input for BraveSearchTool."""
29+
"""Input for BraveSearchTool.
30+
31+
Attributes:
32+
query (str): The search query to use (e.g., "latest Brave news").
33+
country (Optional[str]): Two-letter country code for geo-targeting (e.g., "US", "BR"). Default is "US".
34+
search_language (Optional[str]): 2 or more character language code for which the search results are provided (e.g., "en", "es"). Default is "en".
35+
count (Optional[int]): Number of results desired. The maximum is 20. Actual number may be less. Default is 10.
36+
offset (Optional[int]): Number of result sets/pages to skip. Max is 9. Default is 0.
37+
safesearch (Optional[Literal["off", "moderate", "strict"]]): Level of safe search to apply. Default is "moderate".
38+
spellcheck (Optional[bool]): Whether to apply spell checking to the search query.
39+
freshness (Optional[Freshness]): Filters search results by date discovered. Supported values: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD
40+
text_decorations (Optional[bool]): Whether strings (e.g., result snippets) should include decoration markers (e.g. highlighting characters).
41+
extra_snippets (Optional[bool]): Whether to include up to 5 extra snippets of text for each result.
42+
operators (Optional[bool]): Whether to apply search operators to the query (e.g., 'site:example.com' or 'intitle:example'). Default is True.
43+
"""
2044

21-
search_query: str = Field(
22-
..., description="Mandatory search query you want to use to search the internet"
45+
query: str = Field(..., description="Search query to perform")
46+
country: str | None = Field(
47+
default="US",
48+
description="Two-letter country code for geo-targeting (e.g., 'US', 'BR').",
49+
)
50+
search_language: str | None = Field(
51+
default="en",
52+
description="2 or more character language code for which the search results are provided (e.g., 'en', 'es').",
53+
)
54+
count: int | None = Field(
55+
default=None,
56+
description="Number of results desired. The maximum is 20. Actual number may be less.",
57+
)
58+
offset: int | None = Field(
59+
default=0,
60+
description="Number of result sets/pages to skip. Max is 9.",
61+
)
62+
safesearch: SafeSearch | None = Field(
63+
default="moderate",
64+
description="Level of safe search to apply. Default is 'moderate'.",
65+
)
66+
spellcheck: bool | None = Field(
67+
default=True,
68+
description="Whether to apply spell checking to the search query.",
69+
)
70+
freshness: Freshness | None = Field(
71+
default=None,
72+
description="Filters search results by when they were discovered. Supported values: pd/pw/pm/py, or YYYY-MM-DDtoYYYY-MM-DD",
73+
)
74+
text_decorations: bool | None = Field(
75+
default=None,
76+
description="Whether strings (e.g., result snippets) should include decoration markers (e.g. highlighting characters).",
77+
)
78+
extra_snippets: bool | None = Field(
79+
default=None,
80+
description="Snippet is an excerpt from a page you get as a result of the query, and extra_snippets allow you to get up to 5 additional, alternative excerpts.",
81+
)
82+
operators: bool | None = Field(
83+
default=True,
84+
description="Whether to apply search operators to the query (e.g., 'site:example.com' or 'intitle:example').",
2385
)
2486

2587

88+
# TODO: Extend support to additional endpoints (e.g., /images, /news, etc.)
2689
class BraveSearchTool(BaseTool):
2790
"""BraveSearchTool - A tool for performing web searches using the Brave Search API.
2891
@@ -36,12 +99,9 @@ class BraveSearchTool(BaseTool):
3699
"""
37100

38101
name: str = "Brave Web Search the internet"
39-
description: str = (
40-
"A tool that can be used to search the internet with a search_query."
41-
)
102+
description: str = "Tool to perform web search using Brave Search API."
42103
args_schema: type[BaseModel] = BraveSearchToolSchema
43104
search_url: str = "https://api.search.brave.com/res/v1/web/search"
44-
country: str | None = ""
45105
n_results: int = 10
46106
save_file: bool = False
47107
_last_request_time: ClassVar[float] = 0
@@ -73,19 +133,61 @@ def _run(
73133
self._min_request_interval - (current_time - self._last_request_time)
74134
)
75135
BraveSearchTool._last_request_time = time.time()
136+
137+
# Construct and send the request
76138
try:
77-
search_query = kwargs.get("search_query") or kwargs.get("query")
78-
if not search_query:
79-
raise ValueError("Search query is required")
139+
# Maintain both "search_query" and "query" for backwards compatibility
140+
query = kwargs.get("search_query") or kwargs.get("query")
141+
if not query:
142+
raise ValueError("Query is required")
143+
144+
payload = {"q": query}
145+
146+
if country := kwargs.get("country"):
147+
payload["country"] = country
148+
149+
if search_language := kwargs.get("search_language"):
150+
payload["search_language"] = search_language
151+
152+
# Fallback to deprecated n_results parameter if no count is provided
153+
count = kwargs.get("count", self.n_results)
154+
payload["count"] = count
155+
156+
# Offset may be 0, so avoid truthiness check
157+
offset = kwargs.get("offset")
158+
if offset is not None:
159+
payload["offset"] = offset
160+
161+
if safesearch := kwargs.get("safesearch"):
162+
payload["safesearch"] = safesearch
80163

81164
save_file = kwargs.get("save_file", self.save_file)
82-
n_results = kwargs.get("n_results", self.n_results)
165+
if freshness := kwargs.get("freshness"):
166+
payload["freshness"] = freshness
167+
168+
# Boolean parameters
169+
spellcheck = kwargs.get("spellcheck")
170+
if spellcheck is not None:
171+
payload["spellcheck"] = spellcheck
83172

84-
payload = {"q": search_query, "count": n_results}
173+
text_decorations = kwargs.get("text_decorations")
174+
if text_decorations is not None:
175+
payload["text_decorations"] = text_decorations
85176

86-
if self.country != "":
87-
payload["country"] = self.country
177+
extra_snippets = kwargs.get("extra_snippets")
178+
if extra_snippets is not None:
179+
payload["extra_snippets"] = extra_snippets
88180

181+
operators = kwargs.get("operators")
182+
if operators is not None:
183+
payload["operators"] = operators
184+
185+
# Limit the result types to "web" since there is presently no
186+
# handling of other types like "discussions", "faq", "infobox",
187+
# "news", "videos", or "locations".
188+
payload["result_filter"] = "web"
189+
190+
# Setup Request Headers
89191
headers = {
90192
"X-Subscription-Token": os.environ["BRAVE_API_KEY"],
91193
"Accept": "application/json",
@@ -97,25 +199,32 @@ def _run(
97199
response.raise_for_status() # Handle non-200 responses
98200
results = response.json()
99201

202+
# TODO: Handle other result types like "discussions", "faq", etc.
203+
web_results_items = []
100204
if "web" in results:
101-
results = results["web"]["results"]
102-
string = []
103-
for result in results:
104-
try:
105-
string.append(
106-
"\n".join(
107-
[
108-
f"Title: {result['title']}",
109-
f"Link: {result['url']}",
110-
f"Snippet: {result['description']}",
111-
"---",
112-
]
113-
)
114-
)
115-
except KeyError: # noqa: PERF203
205+
web_results = results["web"]["results"]
206+
207+
for result in web_results:
208+
url = result.get("url")
209+
title = result.get("title")
210+
# If, for whatever reason, this entry does not have a title
211+
# or url, skip it.
212+
if not url or not title:
116213
continue
117-
118-
content = "\n".join(string)
214+
item = {
215+
"url": url,
216+
"title": title,
217+
}
218+
description = result.get("description")
219+
if description:
220+
item["description"] = description
221+
snippets = result.get("extra_snippets")
222+
if snippets:
223+
item["snippets"] = snippets
224+
225+
web_results_items.append(item)
226+
227+
content = json.dumps(web_results_items)
119228
except requests.RequestException as e:
120229
return f"Error performing search: {e!s}"
121230
except KeyError as e:

lib/crewai-tools/tests/tools/brave_search_tool_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,30 @@ def test_brave_tool_search(mock_get, brave_tool):
3030
}
3131
mock_get.return_value.json.return_value = mock_response
3232

33-
result = brave_tool.run(search_query="test")
33+
result = brave_tool.run(query="test")
3434
assert "Test Title" in result
3535
assert "http://test.com" in result
3636

3737

38-
def test_brave_tool():
39-
tool = BraveSearchTool(
40-
n_results=2,
41-
)
42-
tool.run(search_query="ChatGPT")
38+
@patch("requests.get")
39+
def test_brave_tool(mock_get):
40+
mock_response = {
41+
"web": {
42+
"results": [
43+
{
44+
"title": "Brave Browser",
45+
"url": "https://brave.com",
46+
"description": "Brave Browser description",
47+
}
48+
]
49+
}
50+
}
51+
mock_get.return_value.json.return_value = mock_response
52+
53+
tool = BraveSearchTool(n_results=2)
54+
result = tool.run(query="Brave Browser")
55+
assert "Brave Browser" in result
56+
assert "brave.com" in result
4357

4458

4559
if __name__ == "__main__":

0 commit comments

Comments
 (0)