
PoC: InferenceClient is also a MCPClient #2986


Merged: 26 commits merged into main from mcp-client on May 20, 2025
Commits (26)
c720d86
Add extra dependency
julien-c Apr 8, 2025
a0be544
PoC: `InferenceClient` is also a `MCPClient`
julien-c Apr 8, 2025
2c7329c
[using Claude] change the code to make MCPClient inherit from AsyncIn…
julien-c Apr 8, 2025
cef1bba
Update mcp_client.py
julien-c Apr 8, 2025
42f036e
mcp_client: Support multiple servers (#2987)
julien-c Apr 8, 2025
990a926
Revert "[using Claude] change the code to make MCPClient inherit from…
julien-c Apr 9, 2025
879d2ee
`add_mcp_server`: the env should not be hardcoded here
julien-c Apr 11, 2025
9ee3c68
Handle the "no tool call" case
julien-c Apr 11, 2025
c827256
Merge branch 'main' into mcp-client
Wauplin May 13, 2025
e5d205b
Update setup.py
Wauplin May 13, 2025
7c08143
Merge branch 'mcp-client' of github.com:huggingface/huggingface_hub i…
Wauplin May 13, 2025
67304ce
Async mcp client + example + code quality
Wauplin May 13, 2025
3d422f8
docstring
Wauplin May 13, 2025
1a12eb5
accept ChatCompletionInputMessage as input
Wauplin May 13, 2025
1f2181c
Merge branch 'main' into mcp-client
Wauplin May 13, 2025
5313d8b
lazy loading
Wauplin May 13, 2025
ff1d39b
style
Wauplin May 13, 2025
bc8448d
better type
Wauplin May 13, 2025
b03ef86
no need mcp for dev
Wauplin May 13, 2025
5d9af3a
code quality on Python 3.8
Wauplin May 13, 2025
ee648eb
Merge branch 'main' into mcp-client
Wauplin May 20, 2025
0d6981a
address feedback
Wauplin May 20, 2025
63a37f9
address feedback
Wauplin May 20, 2025
b273cba
do not close client inside of `process_single_turn_with_tools`
Wauplin May 20, 2025
834cef2
docstring, no more warning, garbage collection
Wauplin May 20, 2025
b3ea2ee
docs
Wauplin May 20, 2025
4 changes: 4 additions & 0 deletions setup.py
@@ -64,6 +64,10 @@ def get_version() -> str:

extras["hf_xet"] = ["hf-xet>=1.1.1,<2.0.0"]

extras["mcp"] = [
"mcp>=1.8.0",
] + extras["inference"]

extras["testing"] = (
extras["cli"]
+ extras["inference"]
5 changes: 5 additions & 0 deletions src/huggingface_hub/__init__.py
@@ -440,6 +440,9 @@
"ZeroShotObjectDetectionOutputElement",
"ZeroShotObjectDetectionParameters",
],
"inference._mcp.mcp_client": [
"MCPClient",
],
"inference_api": [
"InferenceApi",
],
@@ -644,6 +647,7 @@
"InferenceEndpointType",
"InferenceTimeoutError",
"KerasModelHubMixin",
"MCPClient",
"ModelCard",
"ModelCardData",
"ModelHubMixin",
@@ -1402,6 +1406,7 @@ def __dir__():
ZeroShotObjectDetectionOutputElement, # noqa: F401
ZeroShotObjectDetectionParameters, # noqa: F401
)
from .inference._mcp.mcp_client import MCPClient # noqa: F401
from .inference_api import InferenceApi # noqa: F401
from .keras_mixin import (
KerasModelHubMixin, # noqa: F401
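For reference, the mapping above plugs MCPClient into the package's lazy-loading table, so a plain top-level import works and the mcp dependency is only pulled in when the attribute is first accessed. A minimal sketch, assuming the optional extra added in setup.py is installed (pip install "huggingface_hub[mcp]"):

from huggingface_hub import MCPClient  # resolved lazily on first access, driven by the dict above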
13 changes: 7 additions & 6 deletions src/huggingface_hub/inference/_client.py
@@ -66,6 +66,7 @@
AudioToAudioOutputElement,
AutomaticSpeechRecognitionOutput,
ChatCompletionInputGrammarType,
ChatCompletionInputMessage,
ChatCompletionInputStreamOptions,
ChatCompletionInputTool,
ChatCompletionInputToolChoiceClass,
@@ -100,7 +101,7 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import PROVIDER_T, get_provider_helper
from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from huggingface_hub.utils._auth import get_token
from huggingface_hub.utils._deprecation import _deprecate_method
@@ -164,7 +165,7 @@ def __init__(
self,
model: Optional[str] = None,
*,
provider: Union[Literal["auto"], PROVIDER_T, None] = None,
provider: Optional[PROVIDER_OR_POLICY_T] = None,
token: Optional[str] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
@@ -446,7 +447,7 @@ def automatic_speech_recognition(
@overload
def chat_completion( # type: ignore
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: Literal[False] = False,
@@ -472,7 +473,7 @@ def chat_completion( # type: ignore
@overload
def chat_completion( # type: ignore
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: Literal[True] = True,
@@ -498,7 +499,7 @@ def chat_completion( # type: ignore
@overload
def chat_completion(
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: bool = False,
@@ -523,7 +524,7 @@ def chat_completion(

def chat_completion(
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: bool = False,
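To illustrate the widened messages annotation in these overloads, a minimal sketch (the model id is a placeholder, not part of this PR): plain dicts and typed ChatCompletionInputMessage objects can now be mixed in the same call.

from huggingface_hub import ChatCompletionInputMessage, InferenceClient

client = InferenceClient(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model id
messages = [
    {"role": "system", "content": "You are a helpful assistant."},  # plain dict, as before
    ChatCompletionInputMessage.parse_obj_as_instance(  # typed message, now accepted too
        {"role": "user", "content": "Hello!"}
    ),
]
response = client.chat_completion(messages)
print(response.choices[0].message.content)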
13 changes: 7 additions & 6 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -51,6 +51,7 @@
AudioToAudioOutputElement,
AutomaticSpeechRecognitionOutput,
ChatCompletionInputGrammarType,
ChatCompletionInputMessage,
ChatCompletionInputStreamOptions,
ChatCompletionInputTool,
ChatCompletionInputToolChoiceClass,
@@ -85,7 +86,7 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import PROVIDER_T, get_provider_helper
from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from huggingface_hub.utils._auth import get_token
from huggingface_hub.utils._deprecation import _deprecate_method
@@ -154,7 +155,7 @@ def __init__(
self,
model: Optional[str] = None,
*,
provider: Union[Literal["auto"], PROVIDER_T, None] = None,
provider: Optional[PROVIDER_OR_POLICY_T] = None,
token: Optional[str] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
@@ -480,7 +481,7 @@ async def automatic_speech_recognition(
@overload
async def chat_completion( # type: ignore
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: Literal[False] = False,
@@ -506,7 +507,7 @@ async def chat_completion( # type: ignore
@overload
async def chat_completion( # type: ignore
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: Literal[True] = True,
@@ -532,7 +533,7 @@ async def chat_completion( # type: ignore
@overload
async def chat_completion(
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: bool = False,
@@ -557,7 +558,7 @@ async def chat_completion(

async def chat_completion(
self,
messages: List[Dict],
messages: List[Union[Dict, ChatCompletionInputMessage]],
*,
model: Optional[str] = None,
stream: bool = False,
Empty file.
236 changes: 236 additions & 0 deletions src/huggingface_hub/inference/_mcp/mcp_client.py
@@ -0,0 +1,236 @@
import json
import logging
import warnings
from contextlib import AsyncExitStack
from pathlib import Path
from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union

from typing_extensions import TypeAlias

from ...utils._runtime import get_hf_hub_version
from .._generated._async_client import AsyncInferenceClient
from .._generated.types import (
ChatCompletionInputMessage,
ChatCompletionInputTool,
ChatCompletionStreamOutput,
ChatCompletionStreamOutputDeltaToolCall,
)
from .._providers import PROVIDER_OR_POLICY_T
from .utils import format_result


if TYPE_CHECKING:
from mcp import ClientSession

logger = logging.getLogger(__name__)

# Type alias for tool names
ToolName: TypeAlias = str


class MCPClient:
def __init__(
self,
*,
model: str,
provider: Optional[PROVIDER_OR_POLICY_T] = None,
api_key: Optional[str] = None,
):
warnings.warn(
"'MCPClient' is experimental and might be subject to breaking changes in the future without prior notice.",
UserWarning,
)
Member Author:
personally think we could omit this, but no strong opinion

Contributor:
Moved it to the docstring.


# Initialize MCP sessions as a dictionary of ClientSession objects
self.sessions: Dict[ToolName, "ClientSession"] = {}
self.exit_stack = AsyncExitStack()
self.available_tools: List[ChatCompletionInputTool] = []

# Initialize the AsyncInferenceClient
self.client = AsyncInferenceClient(model=model, provider=provider, api_key=api_key)

async def __aenter__(self):
"""Enter the context manager"""
await self.client.__aenter__()
Member:
What is the point of entering the client context manager here?

It is only used within the process_single_turn_with_tools method, and the context manager is properly entered and exited there using a with block?

Suggested change (remove this line):
await self.client.__aenter__()

Contributor:
Good catch! I initially intended to delete the async with self.client: below but forgot about it. The issue with including async with self.client: within process_single_turn_with_tools is that concurrent usage of process_single_turn_with_tools may result in unexpected behaviors if one instance terminates the sessions for all others. By relocating the logic to __aenter__, the responsibility shifts to the end user, who can handle the client lifecycle themselves.

Contributor:
updated in b273cba

return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit the context manager"""
await self.client.__aexit__(exc_type, exc_val, exc_tb)
await self.cleanup()
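For illustration, a minimal end-to-end sketch of the lifecycle settled on above: the caller owns the client through the async context manager, and __aexit__ tears down both the inference client and the MCP sessions. The model id and server command are placeholders, not part of this PR.

import asyncio

from huggingface_hub import MCPClient


async def main():
    async with MCPClient(model="Qwen/Qwen2.5-72B-Instruct") as client:  # placeholder model id
        # hypothetical stdio MCP server; any command that speaks MCP over stdio works
        await client.add_mcp_server(
            command="npx",
            args=["@modelcontextprotocol/server-filesystem", "/tmp"],
        )
        async for item in client.process_single_turn_with_tools(
            [{"role": "user", "content": "What files are in /tmp?"}]
        ):
            print(item)  # stream chunks and tool messages, in order


asyncio.run(main())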

async def add_mcp_server(
self,
*,
command: str,
args: Optional[List[str]] = None,
env: Optional[Dict[str, str]] = None,
cwd: Union[str, Path, None] = None,
):
"""Connect to an MCP server

Args:
command (str):
The command to run the MCP server.
args (List[str], optional):
Arguments for the command.
env (Dict[str, str], optional):
Environment variables for the command. Defaults to an empty environment.
cwd (Union[str, Path, None], optional):
Working directory for the command. Defaults to the current directory.
"""
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types
from mcp.client.stdio import stdio_client

logger.info(f"Connecting to MCP server with command: {command} {args}")
server_params = StdioServerParameters(
command=command,
args=args if args is not None else [],
env=env,
cwd=cwd,
)

stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
stdio, write = stdio_transport
session = await self.exit_stack.enter_async_context(
ClientSession(
read_stream=stdio,
write_stream=write,
client_info=mcp_types.Implementation(
name="huggingface_hub.MCPClient",
version=get_hf_hub_version(),
),
)
)

logger.debug("Initializing session...")
await session.initialize()

# List available tools
response = await session.list_tools()
logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])
Contributor:
I'm getting a silent error (TypeError: not all arguments converted during string formatting) because the second positional argument is treated as a %s placeholder.

Suggested change
logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])
logger.debug("Connected to server with tools: %s", [tool.name for tool in response.tools])


for tool in response.tools:
if tool.name in self.sessions:
logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
continue
Member:
I think we should eventually consider supporting tools with the same name coming from different servers. One possible approach could be to prepend the server name to the tool name to ensure uniqueness.
This could be explored in a future PR.

Contributor (@Wauplin, May 20, 2025):
Yes agree 👍 (been mentioned above but for now postponed to a later PR as you said)
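Purely as a hypothetical sketch of that namespacing idea (not part of this PR), the registration below could qualify the tool name with a per-server identifier; the same qualified name would then also need to be used in available_tools and at call time.

server_id = command  # hypothetical: derive an identifier from the server (its command, Implementation name, ...)
qualified_name = f"{server_id}.{tool.name}"
self.sessions[qualified_name] = session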


# Map tool names to their server for later lookup
self.sessions[tool.name] = session

# Add tool to the list of available tools (for use in chat completions)
self.available_tools.append(
ChatCompletionInputTool.parse_obj_as_instance(
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.inputSchema,
},
}
)
)

async def process_single_turn_with_tools(
self,
messages: List[Union[Dict, ChatCompletionInputMessage]],
exit_loop_tools: Optional[List[ChatCompletionInputTool]] = None,
exit_if_first_chunk_no_tool: bool = False,
) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]:
"""Process a query using `self.model` and available tools, yielding chunks and tool outputs.

Args:
query: The user query to process
opts: Optional parameters including:
- exit_loop_tools: List of tools that should exit the generator when called
- exit_if_first_chunk_no_tool: Exit if no tool is present in the first chunks

Yields:
ChatCompletionStreamOutput chunks or ChatCompletionInputMessage objects
"""
# Prepare tools list based on options
tools = self.available_tools
if exit_loop_tools is not None:
tools = [*exit_loop_tools, *self.available_tools]

# Create the streaming request
async with self.client:
Contributor:
Is the async with mandatory here? We already open self.client in __aenter__.

Member:
See my comment above: I proposed keeping the with block here and removing self.client from the __aenter__ above.
The reason is that self.client stores an HTTP response object each time it makes a request, so it's better to clean up this resource after the request/response has been handled (here, in this part of the code) instead of keeping all responses when multiple calls to process_single_turn_with_tools are made.

Contributor:
Indeed, I forgot to remove this part. See my comment here: #2986 (comment).

> The reason is that self.client stores an HTTP response object each time it makes a request, so it's better to clean up this resource after the request/response has been handled (here, in this part of the code) instead of keeping all responses when multiple calls to process_single_turn_with_tools are made.

Currently, if process_single_turn_with_tools is called in parallel, you might get unexpected behaviors because a connection still in use gets closed abruptly. In practice, the InferenceClient won't really have many unclosed sessions in parallel. They are already well garbage-collected, except in the case of a non-awaited streaming response, which shouldn't occur.

response = await self.client.chat.completions.create(
messages=messages,
tools=tools,
tool_choice="auto",
stream=True,
)

message = {"role": "unknown", "content": ""}
final_tool_calls: Dict[int, ChatCompletionStreamOutputDeltaToolCall] = {}
num_of_chunks = 0

# Read from stream
async for chunk in response:
# Yield each chunk to caller
yield chunk

num_of_chunks += 1
delta = chunk.choices[0].delta if chunk.choices and len(chunk.choices) > 0 else None
if not delta:
continue

# Process message
if delta.role:
message["role"] = delta.role
if delta.content:
message["content"] += delta.content

# Process tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
# Aggregate chunks into tool calls
if tool_call.index not in final_tool_calls:
if tool_call.function.arguments is None: # Corner case (depends on provider)
tool_call.function.arguments = ""
final_tool_calls[tool_call.index] = tool_call

if tool_call.function.arguments:
final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments

# Optionally exit early if no tools in first chunks
if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0:
return

if message["content"]:
messages.append(message)

# Process tool calls one by one
for tool_call in final_tool_calls.values():
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments or "{}")

tool_message = {"role": "tool", "tool_call_id": tool_call.id, "content": "", "name": function_name}

# Check if this is an exit loop tool
if exit_loop_tools and function_name in [t.function.name for t in exit_loop_tools]:
tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message)
messages.append(tool_message_as_obj)
yield tool_message_as_obj
return

# Execute tool call with the appropriate session
session = self.sessions.get(function_name)
if session is not None:
result = await session.call_tool(function_name, function_args)
tool_message["content"] = format_result(result)
else:
error_msg = f"Error: No session found for tool: {function_name}"
tool_message["content"] = error_msg

# Yield tool message
tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message)
messages.append(tool_message_as_obj)
yield tool_message_as_obj
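For reference, a small sketch of how a caller might consume this generator, assuming the types are imported from the package root: ChatCompletionStreamOutput items are the raw stream chunks, while ChatCompletionInputMessage items are the tool (or exit-loop) messages that have just been appended to messages.

from huggingface_hub import ChatCompletionInputMessage, ChatCompletionStreamOutput


async def consume_turn(client, messages):
    async for item in client.process_single_turn_with_tools(messages):
        if isinstance(item, ChatCompletionStreamOutput):
            delta = item.choices[0].delta if item.choices else None
            if delta is not None and delta.content:
                print(delta.content, end="", flush=True)  # assistant text as it streams
        elif isinstance(item, ChatCompletionInputMessage):
            print(f"\n[tool result] {item.content}")  # already appended to `messages`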

async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()