Skip to content

Commit 3c016ea

Browse files
4shubShubham Naikmattzh72
authored
feat: allow list tools by tool type [PRO-870] (#4036)
* feat: allow list tools by tool type * chore: update list * chore: respond to comments * chore: refactor tools hella * Add tests to managers * chore: branch --------- Co-authored-by: Shubham Naik <[email protected]> Co-authored-by: Matt Zhou <[email protected]>
1 parent 3e7e063 commit 3c016ea

File tree

3 files changed

+789
-18
lines changed

3 files changed

+789
-18
lines changed

letta/server/rest_api/routers/v1/tools.py

Lines changed: 158 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from letta.orm.errors import UniqueConstraintViolationError
2828
from letta.orm.mcp_oauth import OAuthSessionStatus
2929
from letta.prompts.gpt_system import get_system_text
30-
from letta.schemas.enums import MessageRole
30+
from letta.schemas.enums import MessageRole, ToolType
3131
from letta.schemas.letta_message import ToolReturnMessage
3232
from letta.schemas.letta_message_content import TextContent
3333
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
@@ -62,16 +62,94 @@ async def delete_tool(
6262

6363
@router.get("/count", response_model=int, operation_id="count_tools")
6464
async def count_tools(
65+
name: Optional[str] = None,
66+
names: Optional[List[str]] = Query(None, description="Filter by specific tool names"),
67+
tool_ids: Optional[List[str]] = Query(
68+
None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values"
69+
),
70+
search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"),
71+
tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"),
72+
exclude_tool_types: Optional[List[str]] = Query(
73+
None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values"
74+
),
75+
return_only_letta_tools: Optional[bool] = Query(False, description="Count only tools with tool_type starting with 'letta_'"),
76+
exclude_letta_tools: Optional[bool] = Query(False, description="Exclude built-in Letta tools from the count"),
6577
server: SyncServer = Depends(get_letta_server),
6678
actor_id: Optional[str] = Header(None, alias="user_id"),
67-
include_base_tools: Optional[bool] = Query(False, description="Include built-in Letta tools in the count"),
6879
):
6980
"""
7081
Get a count of all tools available to agents belonging to the org of the user.
7182
"""
7283
try:
84+
# Helper function to parse tool types - supports both repeated params and comma-separated values
85+
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
86+
if tool_types_input is None:
87+
return None
88+
89+
# Flatten any comma-separated values and validate against ToolType enum
90+
flattened_types = []
91+
for item in tool_types_input:
92+
# Split by comma in case user provided comma-separated values
93+
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
94+
flattened_types.extend(types_in_item)
95+
96+
# Validate each type against the ToolType enum
97+
valid_types = []
98+
valid_values = [tt.value for tt in ToolType]
99+
100+
for tool_type in flattened_types:
101+
if tool_type not in valid_values:
102+
raise HTTPException(
103+
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
104+
)
105+
valid_types.append(tool_type)
106+
107+
return valid_types if valid_types else None
108+
109+
# Parse and validate tool types (same logic as list_tools)
110+
tool_types_str = parse_tool_types(tool_types)
111+
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
112+
73113
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
74-
return await server.tool_manager.size_async(actor=actor, include_base_tools=include_base_tools)
114+
115+
# Combine single name with names list for unified processing (same logic as list_tools)
116+
combined_names = []
117+
if name is not None:
118+
combined_names.append(name)
119+
if names is not None:
120+
combined_names.extend(names)
121+
122+
# Use None if no names specified, otherwise use the combined list
123+
final_names = combined_names if combined_names else None
124+
125+
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
126+
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
127+
if tool_ids_input is None:
128+
return None
129+
130+
# Flatten any comma-separated values
131+
flattened_ids = []
132+
for item in tool_ids_input:
133+
# Split by comma in case user provided comma-separated values
134+
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
135+
flattened_ids.extend(ids_in_item)
136+
137+
return flattened_ids if flattened_ids else None
138+
139+
# Parse tool IDs (same logic as list_tools)
140+
final_tool_ids = parse_tool_ids(tool_ids)
141+
142+
# Get the count of tools using unified query
143+
return await server.tool_manager.count_tools_async(
144+
actor=actor,
145+
tool_types=tool_types_str,
146+
exclude_tool_types=exclude_tool_types_str,
147+
names=final_names,
148+
tool_ids=final_tool_ids,
149+
search=search,
150+
return_only_letta_tools=return_only_letta_tools,
151+
exclude_letta_tools=exclude_letta_tools,
152+
)
75153
except Exception as e:
76154
print(f"Error occurred: {e}")
77155
raise HTTPException(status_code=500, detail=str(e))
@@ -99,20 +177,93 @@ async def list_tools(
99177
after: Optional[str] = None,
100178
limit: Optional[int] = 50,
101179
name: Optional[str] = None,
180+
names: Optional[List[str]] = Query(None, description="Filter by specific tool names"),
181+
tool_ids: Optional[List[str]] = Query(
182+
None, description="Filter by specific tool IDs - accepts repeated params or comma-separated values"
183+
),
184+
search: Optional[str] = Query(None, description="Search tool names (case-insensitive partial match)"),
185+
tool_types: Optional[List[str]] = Query(None, description="Filter by tool type(s) - accepts repeated params or comma-separated values"),
186+
exclude_tool_types: Optional[List[str]] = Query(
187+
None, description="Tool type(s) to exclude - accepts repeated params or comma-separated values"
188+
),
189+
return_only_letta_tools: Optional[bool] = Query(False, description="Return only tools with tool_type starting with 'letta_'"),
102190
server: SyncServer = Depends(get_letta_server),
103191
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
104192
):
105193
"""
106194
Get a list of all tools available to agents belonging to the org of the user
107195
"""
108196
try:
197+
# Helper function to parse tool types - supports both repeated params and comma-separated values
198+
def parse_tool_types(tool_types_input: Optional[List[str]]) -> Optional[List[str]]:
199+
if tool_types_input is None:
200+
return None
201+
202+
# Flatten any comma-separated values and validate against ToolType enum
203+
flattened_types = []
204+
for item in tool_types_input:
205+
# Split by comma in case user provided comma-separated values
206+
types_in_item = [t.strip() for t in item.split(",") if t.strip()]
207+
flattened_types.extend(types_in_item)
208+
209+
# Validate each type against the ToolType enum
210+
valid_types = []
211+
valid_values = [tt.value for tt in ToolType]
212+
213+
for tool_type in flattened_types:
214+
if tool_type not in valid_values:
215+
raise HTTPException(
216+
status_code=400, detail=f"Invalid tool_type '{tool_type}'. Must be one of: {', '.join(valid_values)}"
217+
)
218+
valid_types.append(tool_type)
219+
220+
return valid_types if valid_types else None
221+
222+
# Parse and validate tool types
223+
tool_types_str = parse_tool_types(tool_types)
224+
exclude_tool_types_str = parse_tool_types(exclude_tool_types)
225+
109226
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
227+
228+
# Combine single name with names list for unified processing
229+
combined_names = []
110230
if name is not None:
111-
tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor)
112-
return [tool] if tool else []
231+
combined_names.append(name)
232+
if names is not None:
233+
combined_names.extend(names)
113234

114-
# Get the list of tools
115-
return await server.tool_manager.list_tools_async(actor=actor, after=after, limit=limit)
235+
# Use None if no names specified, otherwise use the combined list
236+
final_names = combined_names if combined_names else None
237+
238+
# Helper function to parse tool IDs - supports both repeated params and comma-separated values
239+
def parse_tool_ids(tool_ids_input: Optional[List[str]]) -> Optional[List[str]]:
240+
if tool_ids_input is None:
241+
return None
242+
243+
# Flatten any comma-separated values
244+
flattened_ids = []
245+
for item in tool_ids_input:
246+
# Split by comma in case user provided comma-separated values
247+
ids_in_item = [id.strip() for id in item.split(",") if id.strip()]
248+
flattened_ids.extend(ids_in_item)
249+
250+
return flattened_ids if flattened_ids else None
251+
252+
# Parse tool IDs
253+
final_tool_ids = parse_tool_ids(tool_ids)
254+
255+
# Get the list of tools using unified query
256+
return await server.tool_manager.list_tools_async(
257+
actor=actor,
258+
after=after,
259+
limit=limit,
260+
tool_types=tool_types_str,
261+
exclude_tool_types=exclude_tool_types_str,
262+
names=final_names,
263+
tool_ids=final_tool_ids,
264+
search=search,
265+
return_only_letta_tools=return_only_letta_tools,
266+
)
116267
except Exception as e:
117268
# Log or print the full exception here for debugging
118269
print(f"Error occurred: {e}")

0 commit comments

Comments
 (0)