PoC: InferenceClient is also a MCPClient #2986
Changes from 8 commits
@@ -0,0 +1,152 @@
import asyncio
import json
import os
from contextlib import AsyncExitStack
from typing import Dict, List, Optional, TypeAlias

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool, ChatCompletionOutput
from huggingface_hub.inference._providers import PROVIDER_T


# Type alias for tool names
ToolName: TypeAlias = str


class MCPClient:
    def __init__(
        self,
        *,
        provider: PROVIDER_T,
        model: str,
        api_key: Optional[str] = None,
    ):
        self.client = AsyncInferenceClient(
            provider=provider,
            api_key=api_key,
        )
        self.model = model
        # Initialize MCP sessions as a dictionary of ClientSession objects
        self.sessions: Dict[ToolName, ClientSession] = {}
        self.exit_stack = AsyncExitStack()
        self.available_tools: List[ChatCompletionInputTool] = []

    async def add_mcp_server(self, command: str, args: List[str], env: Dict[str, str]):
Review thread on add_mcp_server:

Comment: you would need to lighten the requirements on your args a bit if you want this to work with SSE, or is the intent just to support stdio? The rest seems to focus on stdio, so maybe it's by design.

Reply: for now, just stdio, but in the future.

Reply: yes, this is the new spec, but it is backward compatible, and at the level you are working with in this PR I wouldn't expect much to change: the internals of the client will probably change, but the client interface would remain the same. This means that if you accepted something like add_mcp_server(StdioServerParameters | dict) today, the dict being the arguments of sse_client from the Python SDK, you could already support all SSE servers, and potentially the future streaming HTTP servers, with minor adjustments at most.
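A rough sketch of that suggestion (not part of this PR): dispatch on the parameter type so one method covers both transports. It assumes sse_client from mcp.client.sse in the official Python SDK; the names and signatures here are illustrative only.

    from typing import Union

    from mcp import ClientSession, StdioServerParameters
    from mcp.client.sse import sse_client  # assumed import from the official Python SDK
    from mcp.client.stdio import stdio_client


    # Sketch of a drop-in replacement method for MCPClient
    async def add_mcp_server(self, params: Union[StdioServerParameters, dict]):
        """Accept stdio parameters or a dict of sse_client keyword arguments."""
        if isinstance(params, StdioServerParameters):
            transport = await self.exit_stack.enter_async_context(stdio_client(params))
        else:
            # e.g. params = {"url": "http://localhost:8000/sse"}
            transport = await self.exit_stack.enter_async_context(sse_client(**params))
        read_stream, write_stream = transport
        session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
        await session.initialize()
        # ... tool discovery continues as in the PR below ...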
        """Connect to an MCP server

        Args:
            command (str): executable that launches the stdio server (e.g. "node").
            args (List[str]): command-line arguments passed to the executable.
            env (Dict[str, str]): environment variables for the server process.
        """
        server_params = StdioServerParameters(command=command, args=args, env=env)

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        stdio, write = stdio_transport
        session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))

        await session.initialize()

        # List available tools
        response = await session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])

        # Map tool names to their server for later lookup
        for tool in tools:
            self.sessions[tool.name] = session

        self.available_tools += [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,
Review thread on the tool schema:

Comment: just a note that I have seen some MCP servers with jsonref in their input schemas, which sometimes confuses the model. In mcpadapt I had to resolve the jsonref before passing it to the model; might be minor for now.

Reply: confused, or sometimes plainly unsupported by the model SDK, like google genai.

Reply: Interesting, does the spec mention anything about whether jsonref is allowed or not?

Reply: I don't think the spec mentions it; however, it gets auto-generated if you use pydantic models with the official MCP Python SDK's FastMCP syntax. I hit this with one of the MCP servers I use for testing: https://github.com/grll/pubmedmcp
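A minimal sketch of the workaround described above, assuming the third-party jsonref package (not used in this PR): inline every $ref in a tool's inputSchema before handing it to the model.

    import json

    import jsonref  # assumed third-party dependency: pip install jsonref


    def resolve_refs(schema: dict) -> dict:
        """Materialize all $ref pointers into a plain JSON schema."""
        # replace_refs resolves references lazily; round-tripping through
        # jsonref.dumps / json.loads yields an ordinary dict with no proxies.
        return json.loads(jsonref.dumps(jsonref.replace_refs(schema)))


    # Hypothetical usage in the tool list above:
    #     "parameters": resolve_refs(tool.inputSchema),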
                },
            }
            for tool in tools
        ]

    async def process_query(self, query: str) -> ChatCompletionOutput:
        """Process a query using `self.model` and the available tools"""
        messages = [{"role": "user", "content": query}]

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=self.available_tools,
            tool_choice="auto",
        )

        # Process the response and handle tool calls
        tool_calls = response.choices[0].message.tool_calls
        if tool_calls is None or len(tool_calls) == 0:
            return response

        # Keep the assistant message that requested the tool calls, so the
        # follow-up request carries a valid tool-call/tool-result chain
        messages.append(response.choices[0].message)

        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_args = json.loads(tool_call.function.arguments)

            # Get the appropriate session for this tool
            session = self.sessions.get(function_name)
            if session:
                # Execute the tool call with the appropriate session
                result = await session.call_tool(function_name, function_args)
                messages.append(
                    {
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": function_name,
                        "content": result.content[0].text,
                    }
                )
            else:
                error_msg = f"No session found for tool: {function_name}"
                print(f"Error: {error_msg}")
                messages.append(
                    {
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": function_name,
                        "content": f"Error: {error_msg}",
                    }
                )

        enriched_response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
        )

        return enriched_response

    async def cleanup(self):
        """Clean up resources"""
        await self.exit_stack.aclose()


async def main():
    client = MCPClient(
        provider="together",
        model="Qwen/Qwen2.5-72B-Instruct",
        api_key=os.environ["HF_TOKEN"],
    )
    try:
        await client.add_mcp_server(
            "node",
            ["--disable-warning=ExperimentalWarning", f"{os.path.expanduser('~')}/Desktop/hf-mcp/index.ts"],
            {"HF_TOKEN": os.environ["HF_TOKEN"]},
        )
        response = await client.process_query(
            """
            find an app that generates 3D models from text,
            and also get the best paper about transformers
            """
        )
        print("\n" + response.choices[0].message.content)
    finally:
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
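Note that main() above depends on a local TypeScript server under ~/Desktop/hf-mcp. As a usage sketch, the same client could point at any published stdio server; the snippet below assumes npx and the @modelcontextprotocol/server-filesystem package, neither of which is part of this PR.

    # Hypothetical alternative inside main(): connect to a published stdio
    # server via npx instead of the local TypeScript server used in the PoC.
    await client.add_mcp_server(
        "npx",
        ["-y", "@modelcontextprotocol/server-filesystem", os.path.expanduser("~/Desktop")],
        {"PATH": os.environ["PATH"]},  # npx needs PATH, since env replaces the inherited environment
    )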