-
Notifications
You must be signed in to change notification settings - Fork 765
PoC: InferenceClient
is also a MCPClient
#2986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 20 commits
c720d86
a0be544
2c7329c
cef1bba
42f036e
990a926
879d2ee
9ee3c68
c827256
e5d205b
7c08143
67304ce
3d422f8
1a12eb5
1f2181c
5313d8b
ff1d39b
bc8448d
b03ef86
5d9af3a
ee648eb
0d6981a
63a37f9
b273cba
834cef2
b3ea2ee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,236 @@ | ||||||
import json | ||||||
import logging | ||||||
import warnings | ||||||
from contextlib import AsyncExitStack | ||||||
from pathlib import Path | ||||||
from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union | ||||||
|
||||||
from typing_extensions import TypeAlias | ||||||
|
||||||
from ...utils._runtime import get_hf_hub_version | ||||||
from .._generated._async_client import AsyncInferenceClient | ||||||
from .._generated.types import ( | ||||||
ChatCompletionInputMessage, | ||||||
ChatCompletionInputTool, | ||||||
ChatCompletionStreamOutput, | ||||||
ChatCompletionStreamOutputDeltaToolCall, | ||||||
) | ||||||
from .._providers import PROVIDER_OR_POLICY_T | ||||||
from .utils import format_result | ||||||
|
||||||
|
||||||
if TYPE_CHECKING: | ||||||
from mcp import ClientSession | ||||||
|
||||||
# Module-level logger, namespaced to this module per the logging convention.
logger = logging.getLogger(__name__)

# Type alias for tool names; used to key `MCPClient.sessions` so each tool name
# maps back to the MCP session that registered it.
ToolName: TypeAlias = str
|
||||||
|
||||||
class MCPClient:
    """Client that couples an `AsyncInferenceClient` with tools exposed by MCP servers.

    Tools discovered on connected MCP servers are advertised to the model during
    chat completions; when the model requests a tool call, the call is routed to
    the session that registered that tool and the result is fed back into the
    conversation.

    <Tip warning={true}>

    `MCPClient` is experimental and might be subject to breaking changes in the
    future without prior notice.

    </Tip>
    """

    def __init__(
        self,
        *,
        model: str,
        provider: Optional[PROVIDER_OR_POLICY_T] = None,
        api_key: Optional[str] = None,
    ):
        warnings.warn(
            "'MCPClient' is experimental and might be subject to breaking changes in the future without prior notice.",
            UserWarning,
        )

        # Map tool name -> MCP ClientSession that owns it, for routing tool calls.
        self.sessions: Dict[ToolName, "ClientSession"] = {}
        # All MCP transports/sessions are registered here and torn down in cleanup().
        self.exit_stack = AsyncExitStack()
        # Tools advertised to the model on each chat-completion request.
        self.available_tools: List[ChatCompletionInputTool] = []

        # Underlying client used for all chat-completion requests.
        self.client = AsyncInferenceClient(model=model, provider=provider, api_key=api_key)

    async def __aenter__(self):
        """Enter the context manager: open the underlying inference client."""
        await self.client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager: close the inference client, then all MCP sessions."""
        await self.client.__aexit__(exc_type, exc_val, exc_tb)
        await self.cleanup()

    async def add_mcp_server(
        self,
        *,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
        cwd: Union[str, Path, None] = None,
    ):
        """Connect to an MCP server and register its tools.

        Args:
            command (str):
                The command to run the MCP server.
            args (List[str], optional):
                Arguments for the command.
            env (Dict[str, str], optional):
                Environment variables for the command. Defaults to an empty environment.
            cwd (Union[str, Path, None], optional):
                Working directory for the command. Defaults to the current directory.
        """
        # Imported lazily so `mcp` stays an optional dependency.
        from mcp import ClientSession, StdioServerParameters
        from mcp import types as mcp_types
        from mcp.client.stdio import stdio_client

        logger.info(f"Connecting to MCP server with command: {command} {args}")
        server_params = StdioServerParameters(
            command=command,
            args=args if args is not None else [],
            env=env,
            cwd=cwd,
        )

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        stdio, write = stdio_transport
        session = await self.exit_stack.enter_async_context(
            ClientSession(
                read_stream=stdio,
                write_stream=write,
                client_info=mcp_types.Implementation(
                    name="huggingface_hub.MCPClient",
                    version=get_hf_hub_version(),
                ),
            )
        )

        logger.debug("Initializing session...")
        await session.initialize()

        # List available tools.
        response = await session.list_tools()
        # NOTE: pass the list via lazy %-formatting — passing it as a second
        # positional message argument raises a TypeError that `logging` swallows.
        logger.debug("Connected to server with tools: %s", [tool.name for tool in response.tools])

        for tool in response.tools:
            if tool.name in self.sessions:
                # Tool names must be unique across servers for routing; first one wins.
                logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
                continue

            # Map tool names to their server for later lookup/execution.
            self.sessions[tool.name] = session

            # Add tool to the list of available tools (for use in chat completions).
            self.available_tools.append(
                ChatCompletionInputTool.parse_obj_as_instance(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.inputSchema,
                        },
                    }
                )
            )

    async def process_single_turn_with_tools(
        self,
        messages: List[Union[Dict, ChatCompletionInputMessage]],
        exit_loop_tools: Optional[List[ChatCompletionInputTool]] = None,
        exit_if_first_chunk_no_tool: bool = False,
    ) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]:
        """Process a single turn using `self.client` and available tools, yielding chunks and tool outputs.

        Args:
            messages (List[Union[Dict, ChatCompletionInputMessage]]):
                Conversation history. Mutated in place: the assistant message and any
                tool results are appended to it.
            exit_loop_tools (List[ChatCompletionInputTool], optional):
                Tools that should exit the generator when called.
            exit_if_first_chunk_no_tool (bool, optional):
                Exit if no tool call is present in the first chunks of the stream.

        Yields:
            ChatCompletionStreamOutput chunks or ChatCompletionInputMessage objects
        """
        # Prepare tools list based on options.
        tools = self.available_tools
        if exit_loop_tools is not None:
            tools = [*exit_loop_tools, *self.available_tools]

        # Create the streaming request.
        # NOTE: do NOT wrap this in `async with self.client` — the client's lifecycle
        # is managed by the MCPClient context manager; closing it here would break
        # every subsequent turn.
        response = await self.client.chat.completions.create(
            messages=messages,
            tools=tools,
            tool_choice="auto",
            stream=True,
        )

        message = {"role": "unknown", "content": ""}
        final_tool_calls: Dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
        num_of_chunks = 0

        # Read from stream.
        async for chunk in response:
            # Yield each chunk to caller.
            yield chunk

            num_of_chunks += 1
            delta = chunk.choices[0].delta if chunk.choices and len(chunk.choices) > 0 else None
            if not delta:
                continue

            # Accumulate the assistant message across chunks.
            if delta.role:
                message["role"] = delta.role
            if delta.content:
                message["content"] += delta.content

            # Aggregate streamed tool-call fragments by index.
            if delta.tool_calls:
                for tool_call in delta.tool_calls:
                    if tool_call.index not in final_tool_calls:
                        if tool_call.function.arguments is None:  # Corner case (depends on provider)
                            tool_call.function.arguments = ""
                        final_tool_calls[tool_call.index] = tool_call
                    # `elif` (not a bare `if`) — otherwise the first chunk's arguments
                    # would be appended to themselves and duplicated.
                    elif tool_call.function.arguments:
                        final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments

            # Optionally exit early if no tools in first chunks.
            if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
                return

        if message["content"]:
            messages.append(message)

        # Process tool calls one by one.
        for tool_call in final_tool_calls.values():
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments or "{}")

            tool_message = {"role": "tool", "tool_call_id": tool_call.id, "content": "", "name": function_name}

            # Exit-loop tools terminate the generator instead of being executed.
            if exit_loop_tools and function_name in [t.function.name for t in exit_loop_tools]:
                tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message)
                messages.append(tool_message_as_obj)
                yield tool_message_as_obj
                return

            # Execute tool call with the appropriate session.
            session = self.sessions.get(function_name)
            if session is not None:
                result = await session.call_tool(function_name, function_args)
                tool_message["content"] = format_result(result)
            else:
                error_msg = f"Error: No session found for tool: {function_name}"
                tool_message["content"] = error_msg

            # Yield tool message.
            tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message)
            messages.append(tool_message_as_obj)
            yield tool_message_as_obj

    async def cleanup(self):
        """Close all MCP sessions and transports registered on the exit stack."""
        await self.exit_stack.aclose()
Uh oh!
There was an error while loading. Please reload this page.