PoC: InferenceClient is also a MCPClient #2986
@@ -0,0 +1,125 @@
import asyncio
import json
import os
from contextlib import AsyncExitStack
from typing import List, Optional

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool, ChatCompletionOutput
from huggingface_hub.inference._providers import PROVIDER_T


class MCPClient(AsyncInferenceClient):
    def __init__(
        self,
        *,
        provider: PROVIDER_T,
        model: str,
        api_key: Optional[str] = None,
    ):
        super().__init__(
            provider=provider,
            api_key=api_key,
        )
        self.model = model
        # Initialize MCP session and client objects
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        self.available_tools: List[ChatCompletionInputTool] = []

    async def add_mcp_server(self, command: str, args: List[str]):

why not name this method […]

yes we can

i mean we would need to store a map of […]

let me know if the following question is out of scope.

Option 1

    client1 = MCPClient()
    client1.add_mcp_server()
    client2 = MCPClient()
    client2.add_mcp_server()

or Option 2

    client = MCPClient()
    client.add_mcp_server(server1)
    client.add_mcp_server(server2)

?

another design question: […]

Sorry to chime in unannounced, but from a very removed external user standpoint, I find this all very confusing - I just don't think what you coded should be called […]. When I came to this PR I was fully expecting […]. That being said, it's just about semantics, but I'm kind of a semantics extremist, sorry about that (and feel free to completely disregard this message, as is very likely XD)

What do you mean as parameter? Do you have an example signature?

second option of #2986 (comment)

ah yes, sure, we can probably add this I guess

ah actually with the async/await stuff i'm not so sure.

        """Connect to an MCP server

        Args:
            todo
        """
        server_params = StdioServerParameters(command=command, args=args, env={"HF_TOKEN": os.environ["HF_TOKEN"]})

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

        await self.session.initialize()

        # List available tools
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])
        self.available_tools += [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.inputSchema,

just a note that I have seen some MCP servers with jsonref in their description which sometimes confuses the model. In mcpadapt I had to resolve the jsonref before passing it to the model, might be minor for now

confused or sometimes plain unsupported by the model sdk like google genai...

Interesting, does the spec mention anything about whether jsonref is allowed or not?

I don't think the spec mentions it. However, it gets auto-generated if you use pydantic models with the official mcp python sdk using the fastMCP syntax. I had the case for one of my mcp servers I use to test things: https://github.com/grll/pubmedmcp

                },
            }
            for tool in tools
        ]

    async def process_query(self, query: str) -> ChatCompletionOutput:
        """Process a query using the configured model and available tools"""
        messages = [{"role": "user", "content": query}]

        response = await self.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=self.available_tools,
            tool_choice="auto",
        )

        # Process response and handle tool calls
        tool_calls = response.choices[0].message.tool_calls
        if tool_calls:
            for tool_call in tool_calls:
                function_name = tool_call.function.name
                function_args = json.loads(tool_call.function.arguments)

                # Execute tool call
                result = await self.session.call_tool(function_name, function_args)
                messages.append(
                    {
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": function_name,
                        "content": result.content[0].text,
                    }
                )

        function_enriched_response = await self.chat.completions.create(
            model=self.model,
            messages=messages,
        )

        return function_enriched_response

    async def cleanup(self):
        """Clean up resources"""
        await self.exit_stack.aclose()


async def main():
    client = MCPClient(
        provider="together",
        model="Qwen/Qwen2.5-72B-Instruct",
        api_key=os.environ["HF_TOKEN"],
    )
    try:
        await client.add_mcp_server(
            "node", ["--disable-warning=ExperimentalWarning", f"{os.path.expanduser('~')}/Desktop/hf-mcp/index.ts"]
        )
        response = await client.process_query(
            """
            find an app that generates 3D models from text,
            and also get the best paper about transformers
            """
        )
        print("\n" + response.choices[0].message.content)
    finally:
        await client.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
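
The review thread on `"parameters": tool.inputSchema` mentions that some MCP servers produce tool schemas containing JSON references, which can confuse a model or be rejected by an SDK. As a rough illustration of what "resolving the jsonref" could mean, the sketch below inlines local `$ref` pointers into a schema before it would be added to `available_tools`. The helper name `resolve_refs` and the sample schema are hypothetical and not part of this PR. The sketch assumes references are local `#/...` pointers into a `$defs` block, and it does not handle recursive schemas, escaped JSON pointer segments, or keys that sit next to a `$ref`.

```python
from typing import Any, Dict


def resolve_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `schema` with local "$ref" pointers inlined.

    Hypothetical helper: only "#/..." pointers are supported and
    self-referencing (recursive) schemas are not handled.
    """

    def lookup(pointer: str) -> Any:
        # "#/$defs/Author" -> walk ["$defs", "Author"] from the schema root.
        node: Any = schema
        for part in pointer.lstrip("#/").split("/"):
            node = node[part]
        return node

    def walk(node: Any) -> Any:
        if isinstance(node, dict):
            if "$ref" in node:
                # Replace the reference with the (resolved) target.
                return walk(lookup(node["$ref"]))
            # Drop "$defs" blocks, since their content is inlined.
            return {k: walk(v) for k, v in node.items() if k != "$defs"}
        if isinstance(node, list):
            return [walk(item) for item in node]
        return node

    return walk(schema)


# Sample schema in the shape pydantic-style generators tend to emit (made up).
example_schema = {
    "type": "object",
    "properties": {"author": {"$ref": "#/$defs/Author"}},
    "$defs": {"Author": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
resolved = resolve_refs(example_schema)
```

With a helper along these lines, `add_mcp_server` could pass `resolve_refs(tool.inputSchema)` instead of the raw schema. Whether that is needed depends on the provider and model in use.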
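
In the OpenAI-style chat format that the `tools` and `tool_calls` fields in `process_query` follow, a `"role": "tool"` message normally comes after an assistant message carrying the matching `tool_calls`. The diff appends only the tool results, so the sketch below shows one possible shape for the full message list sent to the second completion call. It uses plain dictionaries so it runs without any network access. `build_followup_messages` and the sample data are hypothetical, and whether a given provider requires the assistant turn is an assumption to verify against that provider's documentation.

```python
import json
from typing import Any, Dict, List


def build_followup_messages(
    query: str,
    tool_calls: List[Dict[str, Any]],
    tool_results: Dict[str, str],
) -> List[Dict[str, Any]]:
    """Assemble the messages for the second completion call.

    `tool_calls` mirrors the assistant's tool calls as plain dicts and
    `tool_results` maps each tool call id to the text the tool returned.
    """
    messages: List[Dict[str, Any]] = [{"role": "user", "content": query}]
    # Echo the assistant turn so each tool result has a call to refer to.
    messages.append({"role": "assistant", "content": None, "tool_calls": tool_calls})
    for call in tool_calls:
        messages.append(
            {
                "role": "tool",
                "tool_call_id": call["id"],
                "name": call["function"]["name"],
                "content": tool_results[call["id"]],
            }
        )
    return messages


# Made-up tool call and result, only to exercise the helper.
sample_calls = [
    {
        "id": "call_0",
        "type": "function",
        "function": {"name": "search_spaces", "arguments": json.dumps({"query": "text to 3D"})},
    }
]
followup = build_followup_messages("find an app", sample_calls, {"call_0": "3 results"})
```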