Skip to content

Commit 7640e70

Browse files
committed
mcp: implement sampling with tools
1 parent d1c06cb commit 7640e70

File tree

4 files changed

+953
-7
lines changed

4 files changed

+953
-7
lines changed

mcp/client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ type ClientOptions struct {
6363
// Setting CreateMessageHandler to a non-nil value causes the client to
6464
// advertise the sampling capability.
6565
CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error)
66+
// SamplingSupportsTools indicates that the client's CreateMessageHandler
67+
// supports tool use. If true and CreateMessageHandler is set, the
68+
// sampling.tools capability is advertised.
69+
SamplingSupportsTools bool
70+
// SamplingSupportsContext indicates that the client supports
71+
// includeContext values other than "none".
72+
SamplingSupportsContext bool
6673
// ElicitationHandler handles incoming requests for elicitation/create.
6774
//
6875
// Setting ElicitationHandler to a non-nil value causes the client to
@@ -131,6 +138,12 @@ func (c *Client) capabilities() *ClientCapabilities {
131138
caps.Roots.ListChanged = true
132139
if c.opts.CreateMessageHandler != nil {
133140
caps.Sampling = &SamplingCapabilities{}
141+
if c.opts.SamplingSupportsTools {
142+
caps.Sampling.Tools = &SamplingToolsCapabilities{}
143+
}
144+
if c.opts.SamplingSupportsContext {
145+
caps.Sampling.Context = &SamplingContextCapabilities{}
146+
}
134147
}
135148
if c.opts.ElicitationHandler != nil {
136149
caps.Elicitation = &ElicitationCapabilities{}

mcp/content.go

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ import (
1414
)
1515

1616
// A Content is a [TextContent], [ImageContent], [AudioContent],
17-
// [ResourceLink], or [EmbeddedResource].
17+
// [ResourceLink], [EmbeddedResource], [ToolUseContent], or [ToolResultContent].
18+
//
19+
// Note: [ToolUseContent] and [ToolResultContent] are only valid in sampling
20+
// message contexts (CreateMessageParams/CreateMessageResult).
1821
type Content interface {
1922
MarshalJSON() ([]byte, error)
2023
fromWire(*wireContent)
@@ -183,6 +186,104 @@ func (c *EmbeddedResource) fromWire(wire *wireContent) {
183186
c.Annotations = wire.Annotations
184187
}
185188

189+
// ToolUseContent represents a request from the assistant to invoke a tool.
190+
// This content type is only valid in sampling messages.
191+
type ToolUseContent struct {
192+
// ID is a unique identifier for this tool use, used to match with ToolResultContent.
193+
ID string
194+
// Name is the name of the tool to invoke.
195+
Name string
196+
// Input contains the tool arguments as a JSON object.
197+
Input map[string]any
198+
Meta Meta
199+
}
200+
201+
func (c *ToolUseContent) MarshalJSON() ([]byte, error) {
202+
input := c.Input
203+
if input == nil {
204+
input = map[string]any{}
205+
}
206+
wire := struct {
207+
Type string `json:"type"`
208+
ID string `json:"id"`
209+
Name string `json:"name"`
210+
Input map[string]any `json:"input"`
211+
Meta Meta `json:"_meta,omitempty"`
212+
}{
213+
Type: "tool_use",
214+
ID: c.ID,
215+
Name: c.Name,
216+
Input: input,
217+
Meta: c.Meta,
218+
}
219+
return json.Marshal(wire)
220+
}
221+
222+
func (c *ToolUseContent) fromWire(wire *wireContent) {
223+
c.ID = wire.ID
224+
c.Name = wire.Name
225+
c.Input = wire.Input
226+
c.Meta = wire.Meta
227+
}
228+
229+
// ToolResultContent represents the result of a tool invocation.
230+
// This content type is only valid in sampling messages with role "user".
231+
type ToolResultContent struct {
232+
// ToolUseID references the ID from the corresponding ToolUseContent.
233+
ToolUseID string
234+
// Content holds the unstructured result of the tool call.
235+
Content []Content
236+
// StructuredContent holds an optional structured result as a JSON object.
237+
StructuredContent any
238+
// IsError indicates whether the tool call ended in an error.
239+
IsError bool
240+
Meta Meta
241+
}
242+
243+
func (c *ToolResultContent) MarshalJSON() ([]byte, error) {
244+
// Marshal nested content
245+
var contentWire []*wireContent
246+
for _, content := range c.Content {
247+
data, err := content.MarshalJSON()
248+
if err != nil {
249+
return nil, err
250+
}
251+
var w wireContent
252+
if err := json.Unmarshal(data, &w); err != nil {
253+
return nil, err
254+
}
255+
contentWire = append(contentWire, &w)
256+
}
257+
if contentWire == nil {
258+
contentWire = []*wireContent{} // avoid JSON null
259+
}
260+
261+
wire := struct {
262+
Type string `json:"type"`
263+
ToolUseID string `json:"toolUseId"`
264+
Content []*wireContent `json:"content"`
265+
StructuredContent any `json:"structuredContent,omitempty"`
266+
IsError bool `json:"isError,omitempty"`
267+
Meta Meta `json:"_meta,omitempty"`
268+
}{
269+
Type: "tool_result",
270+
ToolUseID: c.ToolUseID,
271+
Content: contentWire,
272+
StructuredContent: c.StructuredContent,
273+
IsError: c.IsError,
274+
Meta: c.Meta,
275+
}
276+
return json.Marshal(wire)
277+
}
278+
279+
func (c *ToolResultContent) fromWire(wire *wireContent) {
280+
c.ToolUseID = wire.ToolUseID
281+
c.StructuredContent = wire.StructuredContent
282+
c.IsError = wire.IsError
283+
c.Meta = wire.Meta
284+
// Content is handled separately in contentFromWire due to nested content
285+
}
286+
186287
// ResourceContents contains the contents of a specific resource or
187288
// sub-resource.
188289
type ResourceContents struct {
@@ -224,10 +325,9 @@ func (r *ResourceContents) MarshalJSON() ([]byte, error) {
224325

225326
// wireContent is the wire format for content.
226327
// It represents the protocol types TextContent, ImageContent, AudioContent,
227-
// ResourceLink, and EmbeddedResource.
328+
// ResourceLink, EmbeddedResource, ToolUseContent, and ToolResultContent.
228329
// The Type field distinguishes them. In the protocol, each type has a constant
229330
// value for the field.
230-
// At most one of Text, Data, Resource, and URI is non-zero.
231331
type wireContent struct {
232332
Type string `json:"type"`
233333
Text string `json:"text,omitempty"`
@@ -242,6 +342,14 @@ type wireContent struct {
242342
Meta Meta `json:"_meta,omitempty"`
243343
Annotations *Annotations `json:"annotations,omitempty"`
244344
Icons []Icon `json:"icons,omitempty"`
345+
// Fields for ToolUseContent (type: "tool_use")
346+
ID string `json:"id,omitempty"`
347+
Input map[string]any `json:"input,omitempty"`
348+
// Fields for ToolResultContent (type: "tool_result")
349+
ToolUseID string `json:"toolUseId,omitempty"`
350+
ToolResultContent []*wireContent `json:"content,omitempty"` // nested content for tool_result
351+
StructuredContent any `json:"structuredContent,omitempty"`
352+
IsError bool `json:"isError,omitempty"`
245353
}
246354

247355
func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) {
@@ -284,6 +392,27 @@ func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error)
284392
v := new(EmbeddedResource)
285393
v.fromWire(wire)
286394
return v, nil
395+
case "tool_use":
396+
v := new(ToolUseContent)
397+
v.fromWire(wire)
398+
return v, nil
399+
case "tool_result":
400+
v := new(ToolResultContent)
401+
v.fromWire(wire)
402+
// Handle nested content - tool_result content can contain text, image, audio,
403+
// resource_link, and resource (same as CallToolResult.content)
404+
if wire.ToolResultContent != nil {
405+
toolResultContentAllow := map[string]bool{
406+
"text": true, "image": true, "audio": true,
407+
"resource_link": true, "resource": true,
408+
}
409+
nestedContent, err := contentsFromWire(wire.ToolResultContent, toolResultContentAllow)
410+
if err != nil {
411+
return nil, fmt.Errorf("tool_result nested content: %w", err)
412+
}
413+
v.Content = nestedContent
414+
}
415+
return v, nil
287416
}
288-
return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type)
417+
return nil, fmt.Errorf("unrecognized content type %q", wire.Type)
289418
}

mcp/protocol.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,11 @@ type CreateMessageParams struct {
291291
Meta `json:"_meta,omitempty"`
292292
// A request to include context from one or more MCP servers (including the
293293
// caller), to be attached to the prompt. The client may ignore this request.
294+
//
295+
// The default behavior is Default is "none". Values "thisServer" and
296+
// "allServers" are soft-deprecated. Servers SHOULD only use these values if
297+
// the client declares ClientCapabilities.sampling.context. These values may
298+
// be removed in future spec releases.
294299
IncludeContext string `json:"includeContext,omitempty"`
295300
// The maximum number of tokens to sample, as requested by the server. The
296301
// client may choose to sample fewer tokens than requested.
@@ -307,6 +312,12 @@ type CreateMessageParams struct {
307312
// may modify or omit this prompt.
308313
SystemPrompt string `json:"systemPrompt,omitempty"`
309314
Temperature float64 `json:"temperature,omitempty"`
315+
// Tools is an optional list of tools available for the model to use.
316+
// Requires the client's sampling.tools capability.
317+
Tools []*Tool `json:"tools,omitempty"`
318+
// ToolChoice controls how the model should use tools.
319+
// Requires the client's sampling.tools capability.
320+
ToolChoice *ToolChoice `json:"toolChoice,omitempty"`
310321
}
311322

312323
func (x *CreateMessageParams) isParams() {}
@@ -326,6 +337,12 @@ type CreateMessageResult struct {
326337
Model string `json:"model"`
327338
Role Role `json:"role"`
328339
// The reason why sampling stopped, if known.
340+
//
341+
// Standard values:
342+
// - "endTurn": natural end of the assistant's turn
343+
// - "stopSequence": a stop sequence was encountered
344+
// - "maxTokens": reached the maximyum token limit
345+
// - "toolUse": the model wants to use one or more tools
329346
StopReason string `json:"stopReason,omitempty"`
330347
}
331348

@@ -339,8 +356,9 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error {
339356
if err := json.Unmarshal(data, &wire); err != nil {
340357
return err
341358
}
359+
// Allow text, image, audio, and tool_use in results
342360
var err error
343-
if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil {
361+
if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true}); err != nil {
344362
return err
345363
}
346364
*r = CreateMessageResult(wire.result)
@@ -876,7 +894,27 @@ func (x *RootsListChangedParams) GetProgressToken() any { return getProgressTok
876894
func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) }
877895

878896
// SamplingCapabilities describes the capabilities for sampling.
879-
type SamplingCapabilities struct{}
897+
type SamplingCapabilities struct {
898+
// Context indicates the client supports includeContext values other than "none".
899+
Context *SamplingContextCapabilities `json:"context,omitempty"`
900+
// Tools indicates the client supports tools and toolChoice in sampling requests.
901+
Tools *SamplingToolsCapabilities `json:"tools,omitempty"`
902+
}
903+
904+
// SamplingContextCapabilities indicates the client supports context inclusion.
905+
type SamplingContextCapabilities struct{}
906+
907+
// SamplingToolsCapabilities indicates the client supports tool use in sampling.
908+
type SamplingToolsCapabilities struct{}
909+
910+
// ToolChoice controls how the model uses tools during sampling.
911+
type ToolChoice struct {
912+
// Mode controls tool invocation behavior:
913+
// - "auto": Model decides whether to use tools (default)
914+
// - "required": Model must use at least one tool
915+
// - "none": Model must not use any tools
916+
Mode string `json:"mode,omitempty"`
917+
}
880918

881919
// ElicitationCapabilities describes the capabilities for elicitation.
882920
//
@@ -895,6 +933,9 @@ type URLElicitationCapabilities struct {
895933
}
896934

897935
// Describes a message issued to or received from an LLM API.
936+
//
937+
// For assistant messages, Content may be text, image, audio, or tool_use.
938+
// For user messages, Content may be text, image, audio, or tool_result.
898939
type SamplingMessage struct {
899940
Content Content `json:"content"`
900941
Role Role `json:"role"`
@@ -911,8 +952,9 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error {
911952
if err := json.Unmarshal(data, &wire); err != nil {
912953
return err
913954
}
955+
// Allow text, image, audio, tool_use, and tool_result in sampling messages
914956
var err error
915-
if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil {
957+
if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true, "tool_result": true}); err != nil {
916958
return err
917959
}
918960
*m = SamplingMessage(wire.msg)

0 commit comments

Comments
 (0)