diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index fdd71a70..184a31c3 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -50,7 +50,7 @@ jobs: uses: modelcontextprotocol/conformance@a2855b03582a6c0b31065ad4d9af248316ce61a3 # v0.1.15 with: mode: client - command: go run ./conformance/everything-client/main.go + command: go run -tags mcp_go_client_oauth ./conformance/everything-client suite: core expected-failures: ./conformance/baseline.yml node-version: 22 diff --git a/README.md b/README.md index 2b015c6c..ffaf33e1 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,15 @@ contains feature documentation, mapping the MCP spec to the packages above. The following table shows which versions of the Go SDK support which versions of the MCP specification: -| SDK Version | Latest MCP Spec | All Supported MCP Specs | -|-----------------|-------------------|------------------------------------------------| -| v1.2.0+ | 2025-06-18 | 2025-11-25, 2025-06-18, 2025-03-26, 2024-11-05 | -| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | +| SDK Version | Latest MCP Spec | All Supported MCP Specs | +|-----------------|-------------------|----------------------------------------------------| +| v1.4.0+ | 2025-11-25\* | 2025-11-25\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.2.0 - v1.3.1 | 2025-11-25\*\* | 2025-11-25\*\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | + +\* Client side OAuth has experimental support. + +\*\* Partial support for 2025-11-25 (client side OAuth and Sampling with tools not available). New releases of the SDK target only supported versions of Go. See https://go.dev/doc/devel/release#policy for more information. diff --git a/auth/auth.go b/auth/auth.go index 87665121..36ff259e 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -25,8 +25,7 @@ type TokenInfo struct { // session hijacking by ensuring that all requests for a given session // come from the same user. UserID string - // TODO: add standard JWT fields - Extra map[string]any + Extra map[string]any } // The error that a TokenVerifier should return if the token cannot be verified. @@ -106,6 +105,9 @@ func verify(req *http.Request, verifier TokenVerifier, opts *RequireBearerTokenO } return nil, err.Error(), http.StatusInternalServerError } + if tokenInfo == nil { + return nil, "token validation failed", http.StatusInternalServerError + } // Check scopes. All must be present. if opts != nil { diff --git a/auth/authorization_code.go b/auth/authorization_code.go new file mode 100644 index 00000000..2a6ed32b --- /dev/null +++ b/auth/authorization_code.go @@ -0,0 +1,548 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// ClientSecretAuthConfig is used to configure client authentication using client_secret. +// Authentication method will be selected based on the authorization server's supported methods, +// according to the following preference order: +// 1. client_secret_post +// 2. client_secret_basic +type ClientSecretAuthConfig struct { + // ClientID is the client ID to be used for client authentication. + ClientID string + // ClientSecret is the client secret to be used for client authentication. + ClientSecret string +} + +// ClientIDMetadataDocumentConfig is used to configure the Client ID Metadata Document +// based client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents. +// See https://client.dev/ for more information. +type ClientIDMetadataDocumentConfig struct { + // URL is the client identifier URL as per + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-client-id-metadata-document-00#section-3. + URL string +} + +// PreregisteredClientConfig is used to configure a pre-registered client per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration. +// Currently only "client_secret_basic" and "client_secret_post" authentication methods are supported. +type PreregisteredClientConfig struct { + // ClientSecretAuthConfig is the client_secret based configuration to be used for client authentication. + ClientSecretAuthConfig *ClientSecretAuthConfig +} + +// DynamicClientRegistrationConfig is used to configure dynamic client registration per +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration. +type DynamicClientRegistrationConfig struct { + // Metadata to be used in dynamic client registration request as per + // https://datatracker.ietf.org/doc/html/rfc7591#section-2. + Metadata *oauthex.ClientRegistrationMetadata +} + +// AuthorizationResult is the result of an authorization flow. +// It is returned by [AuthorizationCodeHandler].AuthorizationCodeFetcher implementations. +type AuthorizationResult struct { + // Code is the authorization code obtained from the authorization server. + Code string + // State string returned by the authorization server. + State string +} + +// AuthorizationArgs is the input to [AuthorizationCodeHandlerConfig].AuthorizationCodeFetcher. +type AuthorizationArgs struct { + // Authorization URL to be opened in a browser for the user to start the authorization process. + URL string +} + +// AuthorizationCodeHandlerConfig is the configuration for [AuthorizationCodeHandler]. +type AuthorizationCodeHandlerConfig struct { + // Client registration configuration. + // It is attempted in the following order: + // 1. Client ID Metadata Document + // 2. Preregistration + // 3. Dynamic Client Registration + // At least one method must be configured. + ClientIDMetadataDocumentConfig *ClientIDMetadataDocumentConfig + PreregisteredClientConfig *PreregisteredClientConfig + DynamicClientRegistrationConfig *DynamicClientRegistrationConfig + + // RedirectURL is a required URL to redirect to after authorization. + // The caller is responsible for handling the redirect out of band. + // + // If Dynamic Client Registration is used: + // - this field is permitted to be empty, in which case it will be set + // to the first redirect URI from + // DynamicClientRegistrationConfig.Metadata.RedirectURIs. + // - if the field is not empty, it must be one of the redirect URIs in + // DynamicClientRegistrationConfig.Metadata.RedirectURIs. + RedirectURL string + + // AuthorizationCodeFetcher is a required function called to initiate the authorization flow. + // It is responsible for opening the URL in a browser for the user to start the authorization process. + // It should return the authorization code and state once the Authorization Server + // redirects back to the RedirectURL. + AuthorizationCodeFetcher func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) +} + +// AuthorizationCodeHandler is an implementation of [OAuthHandler] that uses +// the authorization code flow to obtain access tokens. +type AuthorizationCodeHandler struct { + config *AuthorizationCodeHandlerConfig + + // tokenSource is the token source to use for authorization. + tokenSource oauth2.TokenSource +} + +var _ OAuthHandler = (*AuthorizationCodeHandler)(nil) + +func (h *AuthorizationCodeHandler) isOAuthHandler() {} + +func (h *AuthorizationCodeHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// NewAuthorizationCodeHandler creates a new AuthorizationCodeHandler. +// It performs validation of the configuration and returns an error if it is invalid. +// The passed config is consumed by the handler and should not be modified after. +func NewAuthorizationCodeHandler(config *AuthorizationCodeHandlerConfig) (*AuthorizationCodeHandler, error) { + if config == nil { + return nil, errors.New("config must be provided") + } + if config.ClientIDMetadataDocumentConfig == nil && + config.PreregisteredClientConfig == nil && + config.DynamicClientRegistrationConfig == nil { + return nil, errors.New("at least one client registration configuration must be provided") + } + if config.AuthorizationCodeFetcher == nil { + return nil, errors.New("AuthorizationCodeFetcher is required") + } + if config.ClientIDMetadataDocumentConfig != nil && !isNonRootHTTPSURL(config.ClientIDMetadataDocumentConfig.URL) { + return nil, fmt.Errorf("client ID metadata document URL must be a non-root HTTPS URL") + } + preCfg := config.PreregisteredClientConfig + if preCfg != nil { + if preCfg.ClientSecretAuthConfig == nil { + return nil, errors.New("ClientSecretAuthConfig is required for pre-registered client") + } + if preCfg.ClientSecretAuthConfig.ClientID == "" || preCfg.ClientSecretAuthConfig.ClientSecret == "" { + return nil, fmt.Errorf("pre-registered client ID or secret is empty") + } + } + dCfg := config.DynamicClientRegistrationConfig + if dCfg != nil { + if dCfg.Metadata == nil { + return nil, errors.New("Metadata is required for dynamic client registration") + } + if len(dCfg.Metadata.RedirectURIs) == 0 { + return nil, errors.New("Metadata.RedirectURIs is required for dynamic client registration") + } + if config.RedirectURL == "" { + config.RedirectURL = dCfg.Metadata.RedirectURIs[0] + } else if !slices.Contains(dCfg.Metadata.RedirectURIs, config.RedirectURL) { + return nil, fmt.Errorf("RedirectURL %q is not in the list of allowed redirect URIs for dynamic client registration", config.RedirectURL) + } + } + if config.RedirectURL == "" { + // If the RedirectURL was supposed to be set by the dynamic client registration, + // it should have been set by now. Otherwise, it is required. + return nil, errors.New("RedirectURL is required") + } + return &AuthorizationCodeHandler{config: config}, nil +} + +func isNonRootHTTPSURL(u string) bool { + pu, err := url.Parse(u) + if err != nil { + return false + } + return pu.Scheme == "https" && pu.Path != "" +} + +// Authorize performs the authorization flow. +// It is designed to perform the whole Authorization Code Grant flow. +// On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. +func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer resp.Body.Close() + + wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) + } + + if resp.StatusCode == http.StatusForbidden && errorFromChallenges(wwwChallenges) != "insufficient_scope" { + // We only want to perform step-up authorization for insufficient_scope errors. + // Returning nil, so that the call is retried immediately and the response + // is handled appropriately by the connection. + // Step-up authorization is defined at + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#step-up-authorization-flow + return nil + } + + prm, err := h.getProtectedResourceMetadata(ctx, wwwChallenges, req.URL.String()) + if err != nil { + return err + } + + asm, err := h.getAuthServerMetadata(ctx, prm) + if err != nil { + return err + } + + resolvedClientConfig, err := h.handleRegistration(ctx, asm) + if err != nil { + return err + } + + scps := scopesFromChallenges(wwwChallenges) + if len(scps) == 0 && len(prm.ScopesSupported) > 0 { + scps = prm.ScopesSupported + } + + cfg := &oauth2.Config{ + ClientID: resolvedClientConfig.clientID, + ClientSecret: resolvedClientConfig.clientSecret, + + Endpoint: oauth2.Endpoint{ + AuthURL: asm.AuthorizationEndpoint, + TokenURL: asm.TokenEndpoint, + AuthStyle: resolvedClientConfig.authStyle, + }, + RedirectURL: h.config.RedirectURL, + Scopes: scps, + } + + authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String()) + if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. + return err + } + + return h.exchangeAuthorizationCode(ctx, cfg, authRes, prm.Resource) +} + +// resourceMetadataURLFromChallenges returns a resource metadata URL from the given "WWW-Authenticate" header challenges, +// or the empty string if there is none. +func resourceMetadataURLFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// scopesFromChallenges returns the scopes from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. +func scopesFromChallenges(cs []oauthex.Challenge) []string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["scope"] != "" { + return strings.Fields(c.Params["scope"]) + } + } + return nil +} + +// errorFromChallenges returns the error from the given "WWW-Authenticate" header challenges. +// It only looks at challenges with the "Bearer" scheme. +func errorFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["error"] != "" { + return c.Params["error"] + } + } + return "" +} + +// getProtectedResourceMetadata returns the protected resource metadata. +// If no metadata was found or the fetched metadata fails security checks, +// it returns an error. +func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, mcpServerURL string) (*oauthex.ProtectedResourceMetadata, error) { + var errs []error + // Use MCP server URL as the resource URI per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. + for _, url := range protectedResourceMetadataURLs(resourceMetadataURLFromChallenges(wwwChallenges), mcpServerURL) { + prm, err := oauthex.GetProtectedResourceMetadata(ctx, url.URL, url.Resource, http.DefaultClient) + if err != nil { + errs = append(errs, err) + continue + } + if prm == nil { + errs = append(errs, fmt.Errorf("protected resource metadata is nil")) + continue + } + return prm, nil + } + return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...)) +} + +type prmURL struct { + // URL represents a URL where Protected Resource Metadata may be retrieved. + URL string + // Resource represents the corresponding resource URL for [URL]. + // It is required to perform validation described in RFC 9728, section 3.3. + Resource string +} + +// protectedResourceMetadataURLs returns a list of URLs to try when looking for +// protected resource metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#protected-resource-metadata-discovery-requirements +func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { + var urls []prmURL + if metadataURL != "" { + urls = append(urls, prmURL{ + URL: metadataURL, + Resource: resourceURL, + }) + } + ru, err := url.Parse(resourceURL) + if err != nil { + return urls + } + mu := *ru + // "At the path of the server's MCP endpoint". + mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") + urls = append(urls, prmURL{ + URL: mu.String(), + Resource: resourceURL, + }) + // "At the root". + mu.Path = "/.well-known/oauth-protected-resource" + ru.Path = "" + urls = append(urls, prmURL{ + URL: mu.String(), + Resource: ru.String(), + }) + return urls +} + +// getAuthServerMetadata returns the authorization server metadata. +// The provided Protected Resource Metadata must not be nil. +// It returns an error if the metadata request fails with non-4xx HTTP status code +// or the fetched metadata fails security checks. +// If no metadata was found, it returns a minimal set of endpoints +// as a fallback to 2025-03-26 spec. +func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) { + var authServerURL string + if len(prm.AuthorizationServers) > 0 { + // Use the first authorization server, similarly to other SDKs. + authServerURL = prm.AuthorizationServers[0] + } else { + // Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server. + authURL, err := url.Parse(prm.Resource) + if err != nil { + return nil, fmt.Errorf("failed to parse resource URL: %v", err) + } + authURL.Path = "" + authServerURL = authURL.String() + } + + for _, u := range authorizationServerMetadataURLs(authServerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm != nil { + return asm, nil + } + } + + // Fallback to 2025-03-26 spec: predefined endpoints. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#fallbacks-for-servers-without-metadata-discovery + asm := &oauthex.AuthServerMeta{ + Issuer: authServerURL, + AuthorizationEndpoint: authServerURL + "/authorize", + TokenEndpoint: authServerURL + "/token", + RegistrationEndpoint: authServerURL + "/register", + } + return asm, nil +} + +// authorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func authorizationServerMetadataURLs(issuerURL string) []string { + var urls []string + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil + } + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls + } + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls +} + +type registrationType int + +const ( + registrationTypeClientIDMetadataDocument registrationType = iota + registrationTypePreregistered + registrationTypeDynamic +) + +type resolvedClientConfig struct { + registrationType registrationType + clientID string + clientSecret string + authStyle oauth2.AuthStyle +} + +func selectTokenAuthMethod(supported []string) oauth2.AuthStyle { + prefOrder := []string{ + // Preferred in OAuth 2.1 draft: https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-14.html#name-client-secret. + "client_secret_post", + "client_secret_basic", + } + for _, method := range prefOrder { + if slices.Contains(supported, method) { + return authMethodToStyle(method) + } + } + return oauth2.AuthStyleAutoDetect +} + +func authMethodToStyle(method string) oauth2.AuthStyle { + switch method { + case "client_secret_post": + return oauth2.AuthStyleInParams + case "client_secret_basic": + return oauth2.AuthStyleInHeader + case "none": + // "none" is equivalent to "client_secret_post" but without sending client secret. + return oauth2.AuthStyleInParams + default: + // "client_secret_basic" is the default per https://datatracker.ietf.org/doc/html/rfc7591#section-2. + return oauth2.AuthStyleInHeader + } +} + +// handleRegistration handles client registration. +// The provided authorization server metadata must be non-nil. +// Support for different registration methods is defined as follows: +// - Client ID Metadata Document: metadata must have +// `ClientIDMetadataDocumentSupported` set to true. +// - Pre-registered client: assumed to be supported. +// - Dynamic client registration: metadata must have +// `RegistrationEndpoint` set to a non-empty value. +func (h *AuthorizationCodeHandler) handleRegistration(ctx context.Context, asm *oauthex.AuthServerMeta) (*resolvedClientConfig, error) { + // 1. Attempt to use Client ID Metadata Document (SEP-991). + cimdCfg := h.config.ClientIDMetadataDocumentConfig + if cimdCfg != nil && asm.ClientIDMetadataDocumentSupported { + return &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: cimdCfg.URL, + }, nil + } + // 2. Attempt to use pre-registered client configuration. + pCfg := h.config.PreregisteredClientConfig + if pCfg != nil { + authStyle := selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported) + return &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: pCfg.ClientSecretAuthConfig.ClientID, + clientSecret: pCfg.ClientSecretAuthConfig.ClientSecret, + authStyle: authStyle, + }, nil + } + // 3. Attempt to use dynamic client registration. + dcrCfg := h.config.DynamicClientRegistrationConfig + if dcrCfg != nil && asm.RegistrationEndpoint != "" { + regResp, err := oauthex.RegisterClient(ctx, asm.RegistrationEndpoint, dcrCfg.Metadata, http.DefaultClient) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + cfg := &resolvedClientConfig{ + registrationType: registrationTypeDynamic, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + authStyle: authMethodToStyle(regResp.TokenEndpointAuthMethod), + } + return cfg, nil + } + return nil, fmt.Errorf("no configured client registration methods are supported by the authorization server") +} + +type authResult struct { + *AuthorizationResult + // usedCodeVerifier is the PKCE code verifier used to obtain the authorization code. + // It is preserved for the token exchange step. + usedCodeVerifier string +} + +// getAuthorizationCode uses the [AuthorizationCodeHandler.AuthorizationCodeFetcher] +// to obtain an authorization code. +func (h *AuthorizationCodeHandler) getAuthorizationCode(ctx context.Context, cfg *oauth2.Config, resourceURL string) (*authResult, error) { + codeVerifier := oauth2.GenerateVerifier() + state := rand.Text() + + authURL := cfg.AuthCodeURL(state, + oauth2.S256ChallengeOption(codeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + ) + + authRes, err := h.config.AuthorizationCodeFetcher(ctx, &AuthorizationArgs{URL: authURL}) + if err != nil { + // Purposefully leaving the error unwrappable so it can be handled by the caller. + return nil, err + } + if authRes.State != state { + return nil, fmt.Errorf("state mismatch") + } + return &authResult{ + AuthorizationResult: authRes, + usedCodeVerifier: codeVerifier, + }, nil +} + +// exchangeAuthorizationCode exchanges the authorization code for a token +// and stores it in a token source. +func (h *AuthorizationCodeHandler) exchangeAuthorizationCode(ctx context.Context, cfg *oauth2.Config, authResult *authResult, resourceURL string) error { + opts := []oauth2.AuthCodeOption{ + oauth2.VerifierOption(authResult.usedCodeVerifier), + oauth2.SetAuthURLParam("resource", resourceURL), + } + token, err := cfg.Exchange(ctx, authResult.Code, opts...) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + h.tokenSource = cfg.TokenSource(ctx, token) + return nil +} diff --git a/auth/authorization_code_test.go b/auth/authorization_code_test.go new file mode 100644 index 00000000..77214cd9 --- /dev/null +++ b/auth/authorization_code_test.go @@ -0,0 +1,658 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +func TestAuthorize(t *testing.T) { + authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test_client_id": { + Secret: "test_client_secret", + RedirectURIs: []string{"http://localhost:12345/callback"}, + }, + }, + }, + }) + authServer.Start(t) + + resourceMux := http.NewServeMux() + resourceServer := httptest.NewServer(resourceMux) + t.Cleanup(resourceServer.Close) + resourceURL := resourceServer.URL + "/resource" + + resourceMux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + })) + + handler, err := NewAuthorizationCodeHandler(&AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:12345/callback", + PreregisteredClientConfig: &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "test_client_id", + ClientSecret: "test_client_secret", + }, + }, + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + // The fake authorization server will redirect to an URL with code and state. + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get(args.URL) + if err != nil { + return nil, fmt.Errorf("failed to visit auth URL: %v", err) + } + defer resp.Body.Close() + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + t.Fatalf("failed to dump response: %v", err) + } + t.Log(string(dump)) + + location, err := resp.Location() + if err != nil { + return nil, fmt.Errorf("failed to get location header: %v", err) + } + return &AuthorizationResult{ + Code: location.Query().Get("code"), + State: location.Query().Get("state"), + }, nil + }, + }) + if err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, resourceURL, nil) + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + resp.Header.Set( + "WWW-Authenticate", + "Bearer resource_metadata="+resourceServer.URL+"/.well-known/oauth-protected-resource/resource", + ) + + if err := handler.Authorize(context.Background(), req, resp); err != nil { + t.Fatalf("Authorize failed: %v", err) + } + + tokenSource, err := handler.TokenSource(t.Context()) + if err != nil { + t.Fatalf("Failed to get token source: %v", err) + } + token, err := tokenSource.Token() + if err != nil { + t.Fatalf("Failed to get token: %v", err) + } + if token.AccessToken != "test_access_token" { + t.Errorf("Expected access token 'test_access_token', got '%s'", token.AccessToken) + } +} + +func TestAuthorize_ForbiddenUnhandledError(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://example.com/resource", nil) + resp := &http.Response{ + StatusCode: http.StatusForbidden, + Header: make(http.Header), + Body: http.NoBody, + Request: req, + } + resp.Header.Set( + "WWW-Authenticate", + "Bearer error=invalid_token", + ) + handler := &AuthorizationCodeHandler{} // No config needed for this test. + err := handler.Authorize(t.Context(), req, resp) + if err != nil { + t.Fatalf("Authorize() failed: %v", err) + } +} + +func TestNewAuthorizationCodeHandler_Success(t *testing.T) { + simpleHandler := func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + return nil, nil + } + tests := []struct { + name string + config *AuthorizationCodeHandlerConfig + }{ + { + name: "ClientIDMetadataDocumentConfig", + config: &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: simpleHandler, + }, + }, + { + name: "PreregisteredClientConfig", + config: &AuthorizationCodeHandlerConfig{ + PreregisteredClientConfig: &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "test_client_id", + ClientSecret: "test_client_secret", + }, + }, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: simpleHandler, + }, + }, + { + name: "DynamicClientRegistrationConfig_NoRedirectURL", + config: &AuthorizationCodeHandlerConfig{ + DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{ + "https://example.com/callback", + }, + }, + }, + AuthorizationCodeFetcher: simpleHandler, + }, + }, + { + name: "DynamicClientRegistrationConfig_WithRedirectURL", + config: &AuthorizationCodeHandlerConfig{ + DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{ + "https://example.com/callback", + }, + }, + }, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: simpleHandler, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := NewAuthorizationCodeHandler(tt.config); err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + }) + } +} + +func TestNewAuthorizationCodeHandler_Error(t *testing.T) { + validConfig := func() *AuthorizationCodeHandlerConfig { + return &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://example.com/client"}, + RedirectURL: "https://example.com/callback", + AuthorizationCodeFetcher: func(ctx context.Context, args *AuthorizationArgs) (*AuthorizationResult, error) { + return nil, nil + }, + } + } + // Ensure the base config is valid + if _, err := NewAuthorizationCodeHandler(validConfig()); err != nil { + t.Fatalf("NewAuthorizationCodeHandler failed: %v", err) + } + + tests := []struct { + name string + config func() *AuthorizationCodeHandlerConfig + }{ + { + name: "NilConfig", + config: func() *AuthorizationCodeHandlerConfig { + return nil + }, + }, + { + name: "NoRegistrationConfig", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.ClientIDMetadataDocumentConfig = nil + cfg.PreregisteredClientConfig = nil + cfg.DynamicClientRegistrationConfig = nil + return cfg + }, + }, + { + name: "MissingRedirectURL", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.RedirectURL = "" + return cfg + }, + }, + { + name: "MissingAuthorizationCodeFetcher", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.AuthorizationCodeFetcher = nil + return cfg + }, + }, + { + name: "InvalidMetadataURL", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.ClientIDMetadataDocumentConfig.URL = "https://example.com" + return cfg + }, + }, + { + name: "InvalidPreregistered_MissingSecretConfig", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.PreregisteredClientConfig = &PreregisteredClientConfig{} + return cfg + }, + }, + { + name: "InvalidPreregistered_EmptyID", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.PreregisteredClientConfig = &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientSecret: "secret", + }, + } + return cfg + }, + }, + { + name: "InvalidPreregistered_EmptySecret", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.PreregisteredClientConfig = &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "test_client_id", + }, + } + return cfg + }, + }, + { + name: "InvalidDynamic_MissingMetadata", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.DynamicClientRegistrationConfig = &DynamicClientRegistrationConfig{} + return cfg + }, + }, + { + name: "InvalidDynamic_InconsistentRedirectURI", + config: func() *AuthorizationCodeHandlerConfig { + cfg := validConfig() + cfg.DynamicClientRegistrationConfig = &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"https://example.com/callback1"}, + }, + } + cfg.RedirectURL = "https://example.com/callback2" + return cfg + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewAuthorizationCodeHandler(tt.config()) + if err == nil { + t.Errorf("NewAuthorizationCodeHandler() = nil, want error") + } + }) + } +} + +func TestGetProtectedResourceMetadata(t *testing.T) { + handler := &AuthorizationCodeHandler{} // No config needed for this method + pathForChallenge := "/protected-resource" + + tests := []struct { + name string + challengesProvided bool + prmPath string + resourcePath string + wantError bool + }{ + { + name: "FromChallenges", + challengesProvided: true, + prmPath: pathForChallenge, + resourcePath: "/resource", + wantError: false, + }, + { + name: "FallbackToEndpoint", + challengesProvided: false, + prmPath: "/.well-known/oauth-protected-resource/resource", + resourcePath: "/resource", + wantError: false, + }, + { + name: "FallbackToRoot", + challengesProvided: false, + prmPath: "/.well-known/oauth-protected-resource", + resourcePath: "", + wantError: false, + }, + { + name: "NoMetadata", + challengesProvided: false, + prmPath: "/incorrect", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + resourceURL := server.URL + tt.resourcePath + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + ScopesSupported: []string{"read", "write"}, + } + mux.Handle(tt.prmPath, ProtectedResourceMetadataHandler(metadata)) + var challenges []oauthex.Challenge + if tt.challengesProvided { + challenges = []oauthex.Challenge{ + { + Scheme: "Bearer", + Params: map[string]string{ + "resource_metadata": server.URL + pathForChallenge, + }, + }, + } + } + + got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, resourceURL) + if err != nil { + if !tt.wantError { + t.Fatalf("getProtectedResourceMetadata() error = %v, want nil", err) + } + return + } + if got == nil { + t.Fatal("getProtectedResourceMetadata() got nil, want metadata") + } + if diff := cmp.Diff(metadata, got); diff != "" { + t.Errorf("getProtectedResourceMetadata() metadata mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestGetAuthServerMetadata(t *testing.T) { + handler := &AuthorizationCodeHandler{} // No config needed for this method + + tests := []struct { + name string + authorizationAtMCPServer bool + issuerPath string + endpointConfig *oauthtest.MetadataEndpointConfig + }{ + { + name: "OAuthEndpoint_Root", + authorizationAtMCPServer: false, + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + }, + { + name: "OpenIDEndpoint_Root", + authorizationAtMCPServer: false, + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDInsertedEndpoint: true, + }, + }, + { + name: "OAuthEndpoint_Path", + authorizationAtMCPServer: false, + issuerPath: "/oauth", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + }, + { + name: "OpenIDEndpoint_Path", + authorizationAtMCPServer: false, + issuerPath: "/openid", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDInsertedEndpoint: true, + }, + }, + { + name: "OpenIDAppendedEndpoint_Path", + authorizationAtMCPServer: false, + issuerPath: "/openid", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOpenIDAppendedEndpoint: true, + }, + }, + { + name: "FallbackToMCPServer", + authorizationAtMCPServer: true, + }, + { + name: "NoMetadata", + issuerPath: "", + endpointConfig: &oauthtest.MetadataEndpointConfig{ + // All metadata endpoints disabled. + ServeOAuthInsertedEndpoint: false, + ServeOpenIDInsertedEndpoint: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + IssuerPath: tt.issuerPath, + MetadataEndpointConfig: tt.endpointConfig, + }) + s.Start(t) + issuerURL := s.URL() + tt.issuerPath + resourceURL := "https://example.com/resource" + authServers := []string{issuerURL} + if tt.authorizationAtMCPServer { + resourceURL = issuerURL + authServers = nil + } + prm := &oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: authServers, + } + + got, err := handler.getAuthServerMetadata(t.Context(), prm) + if err != nil { + t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + } + if got == nil { + t.Fatal("getAuthServerMetadata() got nil, want metadata") + } + if got.Issuer != issuerURL { + t.Errorf("getAuthServerMetadata() issuer = %q, want %q", got.Issuer, issuerURL) + } + }) + } +} + +func TestSelectTokenAuthMethod(t *testing.T) { + tests := []struct { + name string + supported []string + want oauth2.AuthStyle + }{ + { + name: "PostPreferredOverBasic", + supported: []string{"client_secret_basic", "client_secret_post"}, + want: oauth2.AuthStyleInParams, + }, + { + name: "BasicChosenIfPostNotSupported", + supported: []string{"private_key_jwt", "client_secret_basic"}, + want: oauth2.AuthStyleInHeader, + }, + { + name: "NoneSupported", + supported: []string{"private_key_jwt"}, + want: oauth2.AuthStyleAutoDetect, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := selectTokenAuthMethod(tt.supported) + if got != tt.want { + t.Errorf("selectTokenAuthMethod() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandleRegistration(t *testing.T) { + tests := []struct { + name string + serverConfig *oauthtest.RegistrationConfig + handlerConfig *AuthorizationCodeHandlerConfig + asm *oauthex.AuthServerMeta + want *resolvedClientConfig + wantError bool + }{ + { + name: "ClientIDMetadataDocument", + serverConfig: &oauthtest.RegistrationConfig{ + ClientIDMetadataDocumentSupported: true, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"}, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypeClientIDMetadataDocument, + clientID: "https://client.example.com", + }, + }, + { + name: "Preregistered", + serverConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "pre_client_id": { + Secret: "pre_client_secret", + }, + }, + }, + handlerConfig: &AuthorizationCodeHandlerConfig{ + PreregisteredClientConfig: &PreregisteredClientConfig{ + ClientSecretAuthConfig: &ClientSecretAuthConfig{ + ClientID: "pre_client_id", + ClientSecret: "pre_client_secret", + }, + }, + }, + want: &resolvedClientConfig{ + registrationType: registrationTypePreregistered, + clientID: "pre_client_id", + clientSecret: "pre_client_secret", + authStyle: oauth2.AuthStyleInParams, + }, + }, + { + name: "NoneSupported", + handlerConfig: &AuthorizationCodeHandlerConfig{ + ClientIDMetadataDocumentConfig: &ClientIDMetadataDocumentConfig{URL: "https://client.example.com"}, + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{RegistrationConfig: tt.serverConfig}) + s.Start(t) + handler := &AuthorizationCodeHandler{config: tt.handlerConfig} + asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ + AuthorizationServers: []string{s.URL()}, + }) + if err != nil { + t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + } + got, err := handler.handleRegistration(t.Context(), asm) + if err != nil { + if !tt.wantError { + t.Fatalf("handleRegistration() unexpected error = %v", err) + } + return + } + if got.registrationType != tt.want.registrationType { + t.Errorf("handleRegistration() registrationType = %v, want %v", got.registrationType, tt.want.registrationType) + } + if got.clientID != tt.want.clientID { + t.Errorf("handleRegistration() clientID = %q, want %q", got.clientID, tt.want.clientID) + } + if got.clientSecret != tt.want.clientSecret { + t.Errorf("handleRegistration() clientSecret = %q, want %q", got.clientSecret, tt.want.clientSecret) + } + if got.authStyle != tt.want.authStyle { + t.Errorf("handleRegistration() authStyle = %v, want %v", got.authStyle, tt.want.authStyle) + } + }) + } +} + +func TestDynamicRegistration(t *testing.T) { + s := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + RegistrationConfig: &oauthtest.RegistrationConfig{ + DynamicClientRegistrationEnabled: true, + }, + }) + s.Start(t) + handler := &AuthorizationCodeHandler{config: &AuthorizationCodeHandlerConfig{ + DynamicClientRegistrationConfig: &DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{}, + }, + }} + asm, err := handler.getAuthServerMetadata(t.Context(), &oauthex.ProtectedResourceMetadata{ + AuthorizationServers: []string{s.URL()}, + }) + if err != nil { + t.Fatalf("getAuthServerMetadata() error = %v, want nil", err) + } + got, err := handler.handleRegistration(t.Context(), asm) + if err != nil { + t.Fatalf("handleRegistration() error = %v, want nil", err) + } + if got.registrationType != registrationTypeDynamic { + t.Errorf("handleRegistration() registrationType = %v, want %v", got.registrationType, registrationTypeDynamic) + } + if got.clientID == "" { + t.Errorf("handleRegistration() clientID = %q, want non-empty", got.clientID) + } + if got.clientSecret == "" { + t.Errorf("handleRegistration() clientSecret = %q, want non-empty", got.clientSecret) + } + if got.authStyle != oauth2.AuthStyleInHeader { + t.Errorf("handleRegistration() authStyle = %v, want %v", got.authStyle, oauth2.AuthStyleInHeader) + } +} diff --git a/auth/client.go b/auth/client.go index acadc51b..0af6963f 100644 --- a/auth/client.go +++ b/auth/client.go @@ -2,122 +2,41 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. -//go:build mcp_go_client_oauth - package auth import ( - "bytes" - "errors" - "io" + "context" "net/http" - "sync" "golang.org/x/oauth2" ) -// An OAuthHandler conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization -// is approved, or an error if not. -// The handler receives the HTTP request and response that triggered the authentication flow. -// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. -type OAuthHandler func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) - -// HTTPTransport is an [http.RoundTripper] that follows the MCP -// OAuth protocol when it encounters a 401 Unauthorized response. -type HTTPTransport struct { - handler OAuthHandler - mu sync.Mutex // protects opts.Base - opts HTTPTransportOptions -} - -// NewHTTPTransport returns a new [*HTTPTransport]. -// The handler is invoked when an HTTP request results in a 401 Unauthorized status. -// It is called only once per transport. Once a TokenSource is obtained, it is used -// for the lifetime of the transport; subsequent 401s are not processed. -func NewHTTPTransport(handler OAuthHandler, opts *HTTPTransportOptions) (*HTTPTransport, error) { - if handler == nil { - return nil, errors.New("handler cannot be nil") - } - t := &HTTPTransport{ - handler: handler, - } - if opts != nil { - t.opts = *opts - } - if t.opts.Base == nil { - t.opts.Base = http.DefaultTransport - } - return t, nil -} - -// HTTPTransportOptions are options to [NewHTTPTransport]. -type HTTPTransportOptions struct { - // Base is the [http.RoundTripper] to use. - // If nil, [http.DefaultTransport] is used. - Base http.RoundTripper -} - -func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - t.mu.Lock() - base := t.opts.Base - t.mu.Unlock() - - var ( - // If haveBody is set, the request has a nontrivial body, and we need avoid - // reading (or closing) it multiple times. In that case, bodyBytes is its - // content. - haveBody bool - bodyBytes []byte - ) - if req.Body != nil && req.Body != http.NoBody { - // if we're setting Body, we must mutate first. - req = req.Clone(req.Context()) - haveBody = true - var err error - bodyBytes, err = io.ReadAll(req.Body) - if err != nil { - return nil, err - } - // Now that we've read the request body, http.RoundTripper requires that we - // close it. - req.Body.Close() // ignore error - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - resp, err := base.RoundTrip(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusUnauthorized { - return resp, nil - } - if _, ok := base.(*oauth2.Transport); ok { - // We failed to authorize even with a token source; give up. - return resp, nil - } - - resp.Body.Close() - // Try to authorize. - t.mu.Lock() - defer t.mu.Unlock() - // If we don't have a token source, get one by following the OAuth flow. - // (We may have obtained one while t.mu was not held above.) - // TODO: We hold the lock for the entire OAuth flow. This could be a long - // time. Is there a better way? - if _, ok := t.opts.Base.(*oauth2.Transport); !ok { - ts, err := t.handler(req, resp) - if err != nil { - return nil, err - } - t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} - } - - // If we don't have a body, the request is reusable, though it will be cloned - // by the base. However, if we've had to read the body, we must clone. - if haveBody { - req = req.Clone(req.Context()) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - } - - return t.opts.Base.RoundTrip(req) +// OAuthHandler is an interface for handling OAuth flows. +// +// If a transport wishes to support OAuth 2 authorization, it should support +// being configured with an OAuthHandler. It should call the handler's +// TokenSource method whenever it sends an HTTP request to set the +// Authorization header. If a request fails with a 401 or 403, it should call +// Authorize, and if that returns nil, it should retry the request. It should +// not call Authorize after the second failure. See +// [github.com/modelcontextprotocol/go-sdk/mcp.StreamableClientTransport] +// for an example. +type OAuthHandler interface { + isOAuthHandler() + + // TokenSource returns a token source to be used for outgoing requests. + // Returned token source might be nil. In that case, the transport will not + // add any authorization headers to the request. + TokenSource(context.Context) (oauth2.TokenSource, error) + + // Authorize is called when an HTTP request results in an error that may + // be addressed by the authorization flow (currently 401 Unauthorized and 403 Forbidden). + // It is responsible for performing the OAuth flow to obtain an access token. + // The arguments are the request that failed and the response that was received for it. + // The headers of the request are available, but the body will have already been consumed + // when Authorize is called. + // If the returned error is nil, TokenSource is expected to return a non-nil token source. + // After a successful call to Authorize, the HTTP request will be retried by the transport. + // The function is responsible for closing the response body. + Authorize(context.Context, *http.Request, *http.Response) error } diff --git a/auth/client_private.go b/auth/client_private.go new file mode 100644 index 00000000..767c59ee --- /dev/null +++ b/auth/client_private.go @@ -0,0 +1,135 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "bytes" + "errors" + "io" + "net/http" + "sync" + + "golang.org/x/oauth2" +) + +// An OAuthHandlerLegacy conducts an OAuth flow and returns a [oauth2.TokenSource] if the authorization +// is approved, or an error if not. +// The handler receives the HTTP request and response that triggered the authentication flow. +// To obtain the protected resource metadata, call [oauthex.GetProtectedResourceMetadataFromHeader]. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +type OAuthHandlerLegacy func(req *http.Request, res *http.Response) (oauth2.TokenSource, error) + +// HTTPTransport is an [http.RoundTripper] that follows the MCP +// OAuth protocol when it encounters a 401 Unauthorized response. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +type HTTPTransport struct { + handler OAuthHandlerLegacy + mu sync.Mutex // protects opts.Base + opts HTTPTransportOptions +} + +// NewHTTPTransport returns a new [*HTTPTransport]. +// The handler is invoked when an HTTP request results in a 401 Unauthorized status. +// It is called only once per transport. Once a TokenSource is obtained, it is used +// for the lifetime of the transport; subsequent 401s are not processed. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +func NewHTTPTransport(handler OAuthHandlerLegacy, opts *HTTPTransportOptions) (*HTTPTransport, error) { + if handler == nil { + return nil, errors.New("handler cannot be nil") + } + t := &HTTPTransport{ + handler: handler, + } + if opts != nil { + t.opts = *opts + } + if t.opts.Base == nil { + t.opts.Base = http.DefaultTransport + } + return t, nil +} + +// HTTPTransportOptions are options to [NewHTTPTransport]. +// +// Deprecated: Please use the new [OAuthHandler] abstraction that is built +// into the streamable transport. This struct will be removed in v1.5.0. +type HTTPTransportOptions struct { + // Base is the [http.RoundTripper] to use. + // If nil, [http.DefaultTransport] is used. + Base http.RoundTripper +} + +func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mu.Lock() + base := t.opts.Base + t.mu.Unlock() + + var ( + // If haveBody is set, the request has a nontrivial body, and we need avoid + // reading (or closing) it multiple times. In that case, bodyBytes is its + // content. + haveBody bool + bodyBytes []byte + ) + if req.Body != nil && req.Body != http.NoBody { + // if we're setting Body, we must mutate first. + req = req.Clone(req.Context()) + haveBody = true + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + // Now that we've read the request body, http.RoundTripper requires that we + // close it. + req.Body.Close() // ignore error + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + resp, err := base.RoundTrip(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusUnauthorized { + return resp, nil + } + if _, ok := base.(*oauth2.Transport); ok { + // We failed to authorize even with a token source; give up. + return resp, nil + } + + resp.Body.Close() + // Try to authorize. + t.mu.Lock() + defer t.mu.Unlock() + // If we don't have a token source, get one by following the OAuth flow. + // (We may have obtained one while t.mu was not held above.) + // TODO: We hold the lock for the entire OAuth flow. This could be a long + // time. Is there a better way? + if _, ok := t.opts.Base.(*oauth2.Transport); !ok { + ts, err := t.handler(req, resp) + if err != nil { + return nil, err + } + t.opts.Base = &oauth2.Transport{Base: t.opts.Base, Source: ts} + } + + // If we don't have a body, the request is reusable, though it will be cloned + // by the base. However, if we've had to read the body, we must clone. + if haveBody { + req = req.Clone(req.Context()) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + return t.opts.Base.RoundTrip(req) +} diff --git a/conformance/baseline.yml b/conformance/baseline.yml index ae1f9c63..9d7bef80 100644 --- a/conformance/baseline.yml +++ b/conformance/baseline.yml @@ -1,20 +1,2 @@ server: [] # All tests pass! -client: -- auth/basic-cimd -- auth/metadata-default -- auth/metadata-var1 -- auth/metadata-var2 -- auth/metadata-var3 -- auth/2025-03-26-oauth-metadata-backcompat -- auth/2025-03-26-oauth-endpoint-fallback -- auth/scope-from-www-authenticate -- auth/scope-from-scopes-supported -- auth/scope-omitted-when-undefined -- auth/scope-step-up -- auth/scope-retry-limit -- auth/token-endpoint-auth-basic -- auth/token-endpoint-auth-post -- auth/token-endpoint-auth-none -- auth/client-credentials-jwt -- auth/client-credentials-basic -- auth/pre-registration +client: [] # All tests pass! diff --git a/conformance/everything-client/client_private.go b/conformance/everything-client/client_private.go new file mode 100644 index 00000000..3b0c6592 --- /dev/null +++ b/conformance/everything-client/client_private.go @@ -0,0 +1,139 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// The conformance client implements features required for MCP conformance testing. +// It mirrors the functionality of the TypeScript conformance client at +// https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/conformance/everything-client.ts + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +func init() { + authScenarios := []string{ + "auth/2025-03-26-oauth-metadata-backcompat", + "auth/2025-03-26-oauth-endpoint-fallback", + "auth/basic-cimd", + "auth/metadata-default", + "auth/metadata-var1", + "auth/metadata-var2", + "auth/metadata-var3", + "auth/pre-registration", + "auth/resource-mismatch", + "auth/scope-from-www-authenticate", + "auth/scope-from-scopes-supported", + "auth/scope-omitted-when-undefined", + "auth/scope-step-up", + "auth/scope-retry-limit", + "auth/token-endpoint-auth-basic", + "auth/token-endpoint-auth-post", + "auth/token-endpoint-auth-none", + } + for _, scenario := range authScenarios { + registerScenario(scenario, runAuthClient) + } +} + +// ============================================================================ +// Auth scenarios +// ============================================================================ + +func fetchAuthorizationCodeAndState(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req, err := http.NewRequestWithContext(ctx, "GET", args.URL, nil) + if err != nil { + return nil, err + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // In conformance tests the authorization server immediately redirects + // to the callback URL with the authorization code and state. + locURL, err := url.Parse(resp.Header.Get("Location")) + if err != nil { + return nil, fmt.Errorf("parse location: %v", err) + } + + return &auth.AuthorizationResult{ + Code: locURL.Query().Get("code"), + State: locURL.Query().Get("state"), + }, nil +} + +func runAuthClient(ctx context.Context, serverURL string, configCtx map[string]any) error { + authConfig := &auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "http://localhost:3000/callback", + AuthorizationCodeFetcher: fetchAuthorizationCodeAndState, + // Try client ID metadata document based registration. + ClientIDMetadataDocumentConfig: &auth.ClientIDMetadataDocumentConfig{ + URL: "https://conformance-test.local/client-metadata.json", + }, + // Try dynamic client registration. + DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + Metadata: &oauthex.ClientRegistrationMetadata{ + RedirectURIs: []string{"http://localhost:3000/callback"}, + }, + }, + } + // Try pre-registered client information if provided in the context. + if clientID, ok := configCtx["client_id"].(string); ok { + if clientSecret, ok := configCtx["client_secret"].(string); ok { + authConfig.PreregisteredClientConfig = &auth.PreregisteredClientConfig{ + ClientSecretAuthConfig: &auth.ClientSecretAuthConfig{ + ClientID: clientID, + ClientSecret: clientSecret, + }, + } + } + } + + authHandler, err := auth.NewAuthorizationCodeHandler(authConfig) + if err != nil { + return fmt.Errorf("failed to create auth handler: %w", err) + } + + session, err := connectToServer(ctx, serverURL, withOAuthHandler(authHandler)) + if err != nil { + return err + } + defer session.Close() + + if _, err := session.ListTools(ctx, nil); err != nil { + return fmt.Errorf("session.ListTools(): %v", err) + } + + if _, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "test-tool", + Arguments: map[string]any{}, + }); err != nil { + return fmt.Errorf("session.CallTool('test-tool'): %v", err) + } + + return nil +} + +func withOAuthHandler(handler auth.OAuthHandler) connectOption { + return func(c *connectConfig) { + c.oauthHandler = handler + } +} diff --git a/conformance/everything-client/main.go b/conformance/everything-client/main.go index 9674dbbc..c05fa6f7 100644 --- a/conformance/everything-client/main.go +++ b/conformance/everything-client/main.go @@ -9,6 +9,7 @@ package main import ( "context" + "encoding/json" "fmt" "log" "os" @@ -16,12 +17,13 @@ import ( "sort" "strings" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/mcp" ) // scenarioHandler is the function signature for all conformance test scenarios. // It takes a context and the server URL to connect to. -type scenarioHandler func(ctx context.Context, serverURL string) error +type scenarioHandler func(ctx context.Context, serverURL string, configCtx map[string]any) error var ( // registry stores all registered scenario handlers. @@ -48,7 +50,7 @@ func init() { // Basic scenarios // ============================================================================ -func runBasicClient(ctx context.Context, serverURL string) error { +func runBasicClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -63,7 +65,7 @@ func runBasicClient(ctx context.Context, serverURL string) error { return nil } -func runToolsCallClient(ctx context.Context, serverURL string) error { +func runToolsCallClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -97,7 +99,7 @@ func runToolsCallClient(ctx context.Context, serverURL string) error { // Elicitation scenarios // ============================================================================ -func runElicitationDefaultsClient(ctx context.Context, serverURL string) error { +func runElicitationDefaultsClient(ctx context.Context, serverURL string, _ map[string]any) error { elicitationHandler := func(ctx context.Context, req *mcp.ElicitRequest) (*mcp.ElicitResult, error) { return &mcp.ElicitResult{ Action: "accept", @@ -141,8 +143,7 @@ func runElicitationDefaultsClient(ctx context.Context, serverURL string) error { // SSE retry scenario // ============================================================================ -func runSSERetryClient(ctx context.Context, serverURL string) error { - // TODO: this scenario is not passing yet. It requires a fix in the client SSE handling. +func runSSERetryClient(ctx context.Context, serverURL string, _ map[string]any) error { session, err := connectToServer(ctx, serverURL) if err != nil { return err @@ -185,6 +186,7 @@ func main() { serverURL := os.Args[1] scenarioName := os.Getenv("MCP_CONFORMANCE_SCENARIO") + configCtx := getConformanceContext() if scenarioName == "" { printUsageAndExit("MCP_CONFORMANCE_SCENARIO not set") @@ -196,11 +198,21 @@ func main() { } ctx := context.Background() - if err := handler(ctx, serverURL); err != nil { + if err := handler(ctx, serverURL, configCtx); err != nil { log.Fatalf("Scenario %q failed: %v", scenarioName, err) } } +func getConformanceContext() map[string]any { + ctxStr := os.Getenv("MCP_CONFORMANCE_CONTEXT") + if ctxStr == "" { + return nil + } + var ctx map[string]any + _ = json.Unmarshal([]byte(ctxStr), &ctx) + return ctx +} + func printUsageAndExit(format string, args ...any) { var scenarios []string for name := range registry { @@ -214,6 +226,7 @@ func printUsageAndExit(format string, args ...any) { type connectConfig struct { clientOptions *mcp.ClientOptions + oauthHandler auth.OAuthHandler } type connectOption func(*connectConfig) @@ -237,11 +250,14 @@ func connectToServer(ctx context.Context, serverURL string, opts ...connectOptio Version: "1.0.0", }, config.clientOptions) - transport := &mcp.StreamableClientTransport{Endpoint: serverURL} + transport := &mcp.StreamableClientTransport{ + Endpoint: serverURL, + OAuthHandler: config.oauthHandler, + } session, err := client.Connect(ctx, transport, nil) if err != nil { - return nil, fmt.Errorf("client.Connect(): %v", err) + return nil, fmt.Errorf("client.Connect(): %w", err) } return session, nil diff --git a/docs/protocol.md b/docs/protocol.md index 16ba0bfa..abdf50fa 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -306,9 +306,55 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client -Client-side OAuth is implemented by setting -[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) -Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +> [!IMPORTANT] +> Client-side OAuth support is currently experimental and requires the `mcp_go_client_oauth` build tag to compile. +> API changes may still be made, based on developer feedback. The build tag will be removed in `v1.5.0`, which +> is planned to be released by the end of March 2026. + +Client-side authorization is supported via the +[`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) +field. If the handler is provided, the transport will automatically use it to +add an `Authorization: Bearer ` header to every request. The transport +will also call the handler's `Authorize` method if the server returns +`401 Unauthorized` or `403 Forbidden` errors to perform the authorization flow +or facilitate scope step-up authorization. + +The SDK implements the Authorization Code flow in +[`auth.AuthorizationCodeHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeHandler). +This handler supports: + +- [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) +- [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) +- [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) + +To use it, configure the handler and assign it to the transport: + +```go +authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "https://myapp.com/oauth2-callback", + // Configure one of the following: + // ClientIDMetadataDocumentConfig: ... + // PreregisteredClientConfig: ... + // DynamicClientRegistrationConfig: ... + AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + // Open the args.URL in a browser and return the resulting code and state. + // See full example in examples/auth/client/main.go. + code := ... + state := ... + return &auth.AuthorizationResult{Code: code, State: state}, nil + }, +}) + +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: authHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `auth.AuthorizationCodeHandler` automatically manages token refreshing +and step-up authentication (when the server returns `insufficient_scope` error). ## Security @@ -317,9 +363,12 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, -happens on the MCP client. At present we don't provide client-side OAuth support. - +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +obtaining user consent for dynamically registered clients, is mostly the +responsibility of the MCP Proxy server implementation. The SDK client does +generate cryptographically secure random `state` values for each authorization +request by default and validates them when the authorization code is returned. +Mismatched state values will result in an error. ### Token Passthrough diff --git a/examples/auth/client/main.go b/examples/auth/client/main.go new file mode 100644 index 00000000..32de488e --- /dev/null +++ b/examples/auth/client/main.go @@ -0,0 +1,130 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + // URL of the MCP server. + serverURL = flag.String("server_url", "http://localhost:8000/mcp", "URL of the MCP server.") + // Port for the local HTTP server that will receive the authorization code. + callbackPort = flag.Int("callback_port", 3142, "Port for the local HTTP server that will receive the authorization code.") +) + +type codeReceiver struct { + authChan chan *auth.AuthorizationResult + errChan chan error + listener net.Listener + server *http.Server +} + +func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + r.authChan <- &auth.AuthorizationResult{ + Code: req.URL.Query().Get("code"), + State: req.URL.Query().Get("state"), + } + fmt.Fprint(w, "Authentication successful. You can close this window.") + }) + + r.server = &http.Server{ + Addr: fmt.Sprintf("localhost:%d", *callbackPort), + Handler: mux, + } + if err := r.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + r.errChan <- err + } +} + +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + select { + case authRes := <-r.authChan: + return authRes, nil + case err := <-r.errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (r *codeReceiver) close() { + if r.server != nil { + r.server.Close() + } +} + +func main() { + flag.Parse() + receiver := &codeReceiver{ + authChan: make(chan *auth.AuthorizationResult), + errChan: make(chan error), + } + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *callbackPort)) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + go receiver.serveRedirectHandler(listener) + defer receiver.close() + + authHandler, err := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + AuthorizationCodeFetcher: receiver.getAuthorizationCode, + // Uncomment the client configuration you want to use. + // PreregisteredClientConfig: &auth.PreregisteredClientConfig{ + // ClientID: "", + // ClientSecret: "", + // }, + // DynamicClientRegistrationConfig: &auth.DynamicClientRegistrationConfig{ + // Metadata: &oauthex.ClientRegistrationMetadata{ + // ClientName: "Dynamically registered MCP client", + // RedirectURIs: []string{fmt.Sprintf("http://localhost:%d", *callbackPort)}, + // Scope: "read", + // }, + // }, + }) + if err != nil { + log.Fatalf("failed to create auth handler: %v", err) + } + + transport := &mcp.StreamableClientTransport{ + Endpoint: *serverURL, + OAuthHandler: authHandler, + } + + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + session, err := client.Connect(ctx, transport, nil) + if err != nil { + log.Fatalf("client.Connect(): %v", err) + } + defer session.Close() + + tools, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("session.ListTools(): %v", err) + } + log.Println("Tools:") + for _, tool := range tools.Tools { + log.Printf("- %q", tool.Name) + } +} diff --git a/examples/auth/server/main.go b/examples/auth/server/main.go new file mode 100644 index 00000000..94ad9ae3 --- /dev/null +++ b/examples/auth/server/main.go @@ -0,0 +1,167 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// Flags. +var ( + port = flag.Int("port", 8000, "Port to listen on") +) + +// Configuration required for this example. +var ( + // Authorization server to return in the protected resource metadata. + authorizationServer = "" + // Introspection endpoint for verifying tokens. + introspectionEndpoint = "" + // Client credentials used in the introspection request. + clientID = "" + clientSecret = "" +) + +func verifyToken(ctx context.Context, token string, _ *http.Request) (*auth.TokenInfo, error) { + data := url.Values{} + data.Set("token", token) + data.Set("token_type_hint", "access_token") + + req, err := http.NewRequestWithContext(ctx, "POST", introspectionEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + req.SetBasicAuth(clientID, clientSecret) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + dump, _ := httputil.DumpResponse(resp, true) + log.Printf("Introspection failed: %s", dump) + return nil, fmt.Errorf("introspection failed with status %d", resp.StatusCode) + } + + var result struct { + Active bool `json:"active"` + Scope string `json:"scope"` + Exp int64 `json:"exp"` + Sub string `json:"sub"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + if !result.Active { + return nil, auth.ErrInvalidToken + } + + return &auth.TokenInfo{ + Scopes: strings.Fields(result.Scope), + Expiration: time.Unix(result.Exp, 0), + UserID: result.Sub, + }, nil +} + +type args struct { + Input string `json:"input"` +} + +func echo(ctx context.Context, req *mcp.CallToolRequest, args args) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + &mcp.TextContent{Text: args.Input}, + }, + }, nil, nil +} + +func main() { + flag.Parse() + metadata := &oauthex.ProtectedResourceMetadata{ + Resource: fmt.Sprintf("http://localhost:%d/mcp", *port), + AuthorizationServers: []string{authorizationServer}, + ScopesSupported: []string{"read"}, + } + http.Handle("/.well-known/oauth-protected-resource", auth.ProtectedResourceMetadataHandler(metadata)) + + server := mcp.NewServer(&mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, nil) + server.AddReceivingMiddleware(createLoggingMiddleware()) + mcp.AddTool(server, &mcp.Tool{Name: "echo"}, echo) + + handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server { + return server + }, nil) + + authMiddleware := auth.RequireBearerToken(verifyToken, &auth.RequireBearerTokenOptions{ + Scopes: []string{"read"}, + ResourceMetadataURL: fmt.Sprintf("http://localhost:%d/.well-known/oauth-protected-resource", *port), + }) + + http.Handle("/mcp", authMiddleware(handler)) + + log.Printf("Starting server on http://localhost:%d", *port) + log.Fatal(http.ListenAndServe(fmt.Sprintf("localhost:%d", *port), nil)) +} + +// createLoggingMiddleware creates an MCP middleware that logs method calls. +func createLoggingMiddleware() mcp.Middleware { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func( + ctx context.Context, + method string, + req mcp.Request, + ) (mcp.Result, error) { + start := time.Now() + sessionID := req.GetSession().ID() + + // Log request details. + log.Printf("[REQUEST] Session: %s | Method: %s", + sessionID, + method) + + // Call the actual handler. + result, err := next(ctx, method, req) + + // Log response details. + duration := time.Since(start) + + if err != nil { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: ERROR | Duration: %v | Error: %v", + sessionID, + method, + duration, + err) + } else { + log.Printf("[RESPONSE] Session: %s | Method: %s | Status: OK | Duration: %v", + sessionID, + method, + duration) + } + + return result, err + } + } +} diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index ada34371..3771b581 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -232,9 +232,55 @@ The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/ ### Client -Client-side OAuth is implemented by setting -[`StreamableClientTransport.HTTPClient`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk@v0.5.0/mcp#StreamableClientTransport.HTTPClient) to a custom [`http.Client`](https://pkg.go.dev/net/http#Client) -Additional support is forthcoming; see modelcontextprotocol/go-sdk#493. +> [!IMPORTANT] +> Client-side OAuth support is currently experimental and requires the `mcp_go_client_oauth` build tag to compile. +> API changes may still be made, based on developer feedback. The build tag will be removed in `v1.5.0`, which +> is planned to be released by the end of March 2026. + +Client-side authorization is supported via the +[`StreamableClientTransport.OAuthHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/mcp#StreamableClientTransport.OAuthHandler) +field. If the handler is provided, the transport will automatically use it to +add an `Authorization: Bearer ` header to every request. The transport +will also call the handler's `Authorize` method if the server returns +`401 Unauthorized` or `403 Forbidden` errors to perform the authorization flow +or facilitate scope step-up authorization. + +The SDK implements the Authorization Code flow in +[`auth.AuthorizationCodeHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#AuthorizationCodeHandler). +This handler supports: + +- [Client ID Metadata Documents](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#client-id-metadata-documents) +- [Pre-registered clients](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#preregistration) +- [Dynamic Client Registration](https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#dynamic-client-registration) + +To use it, configure the handler and assign it to the transport: + +```go +authHandler, _ := auth.NewAuthorizationCodeHandler(&auth.AuthorizationCodeHandlerConfig{ + RedirectURL: "https://myapp.com/oauth2-callback", + // Configure one of the following: + // ClientIDMetadataDocumentConfig: ... + // PreregisteredClientConfig: ... + // DynamicClientRegistrationConfig: ... + AuthorizationCodeFetcher: func(ctx context.Context, args *auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + // Open the args.URL in a browser and return the resulting code and state. + // See full example in examples/auth/client/main.go. + code := ... + state := ... + return &auth.AuthorizationResult{Code: code, State: state}, nil + }, +}) + +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: authHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) +``` + +The `auth.AuthorizationCodeHandler` automatically manages token refreshing +and step-up authentication (when the server returns `insufficient_scope` error). ## Security @@ -243,9 +289,12 @@ the MCP spec's [Security Best Practices](https://modelcontextprotocol.io/specifi ### Confused Deputy -The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), obtaining user consent for dynamically registered clients, -happens on the MCP client. At present we don't provide client-side OAuth support. - +The [mitigation](https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#mitigation), +obtaining user consent for dynamically registered clients, is mostly the +responsibility of the MCP Proxy server implementation. The SDK client does +generate cryptographically secure random `state` values for each authorization +request by default and validates them when the authorization code is returned. +Mismatched state values will result in an error. ### Token Passthrough diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go new file mode 100644 index 00000000..1e81f102 --- /dev/null +++ b/internal/oauthtest/fake_authorization_server.go @@ -0,0 +1,304 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthtest + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "slices" + "testing" + + internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +type ClientInfo struct { + Secret string + RedirectURIs []string +} + +type MetadataEndpointConfig struct { + // Whether to serve the OAuth Authorization Server Metadata at + // /.well-known/oauth-authorization-server + issuerPath. + ServeOAuthInsertedEndpoint bool + // Whether to serve the OAuth Authorization Server Metadata at + // /.well-known/openid-configuration + issuerPath. + ServeOpenIDInsertedEndpoint bool + // Whether to serve the OAuth Authorization Server Metadata at + // issuerPath + /.well-known/openid-configuration. + // Should be used when issuerPath is not empty. + ServeOpenIDAppendedEndpoint bool +} + +type RegistrationConfig struct { + // Whether the client ID metadata document is supported. + ClientIDMetadataDocumentSupported bool + // PreregisteredClients is a map of valid ClientIDs to ClientSecrets. + PreregisteredClients map[string]ClientInfo + // Whether dynamic client registration is enabled. + DynamicClientRegistrationEnabled bool +} + +// Config holds configuration for FakeAuthorizationServer. +type Config struct { + // The optional path component of the issuer URL. + // If non-empty, it should start with a "/". It should not end with a "/". + // It affects the paths of the server endpoints. + IssuerPath string + // Configuration of the metadata endpoint. + MetadataEndpointConfig *MetadataEndpointConfig + // Configuration for client registration. + RegistrationConfig *RegistrationConfig +} + +// FakeAuthorizationServer is a fake OAuth 2.0 Authorization Server for testing. +type FakeAuthorizationServer struct { + server *httptest.Server + Mux *http.ServeMux + config Config + clients map[string]ClientInfo + codes map[string]codeInfo +} + +type codeInfo struct { + CodeChallenge string +} + +// NewFakeAuthorizationServer creates a new FakeAuthorizationServer. +// The server is simple and should not be used outside of testing. +// It supports: +// - Only the authorization Code Grant +// - PKCE verification +// - Client tracking & dynamic registration +// - Client authentication +func NewFakeAuthorizationServer(config Config) *FakeAuthorizationServer { + s := &FakeAuthorizationServer{ + Mux: http.NewServeMux(), + config: config, + codes: make(map[string]codeInfo), + } + if config.RegistrationConfig != nil { + s.clients = maps.Clone(config.RegistrationConfig.PreregisteredClients) + } + if s.clients == nil { + s.clients = make(map[string]ClientInfo) + } + + s.Mux.HandleFunc(s.config.IssuerPath+"/authorize", s.handleAuthorize) + s.Mux.HandleFunc(s.config.IssuerPath+"/token", s.handleToken) + if config.MetadataEndpointConfig != nil { + if config.MetadataEndpointConfig.ServeOAuthInsertedEndpoint { + s.Mux.HandleFunc("/.well-known/oauth-authorization-server"+s.config.IssuerPath, s.handleMetadata) + } + if config.MetadataEndpointConfig.ServeOpenIDInsertedEndpoint { + s.Mux.HandleFunc("/.well-known/openid-configuration"+s.config.IssuerPath, s.handleMetadata) + } + if config.MetadataEndpointConfig.ServeOpenIDAppendedEndpoint && s.config.IssuerPath != "" { + s.Mux.HandleFunc(s.config.IssuerPath+"/.well-known/openid-configuration", s.handleMetadata) + } + } else { + // Serve the default OAuth endpoint. + s.Mux.HandleFunc("/.well-known/oauth-authorization-server", s.handleMetadata) + } + if config.RegistrationConfig != nil && config.RegistrationConfig.DynamicClientRegistrationEnabled { + s.Mux.HandleFunc(s.config.IssuerPath+"/register", s.handleRegister) + } + s.server = httptest.NewUnstartedServer(s.Mux) + + return s +} + +// Start starts the HTTP server and registers a cleanup function on t to close the server. +func (s *FakeAuthorizationServer) Start(t testing.TB) { + s.server.Start() + t.Cleanup(s.server.Close) +} + +// URL returns the base URL of the server (Issuer). +func (s *FakeAuthorizationServer) URL() string { + return s.server.URL +} + +func (s *FakeAuthorizationServer) handleMetadata(w http.ResponseWriter, r *http.Request) { + cimdSupported := false + var registrationEndpoint string + if s.config.RegistrationConfig != nil { + cimdSupported = s.config.RegistrationConfig.ClientIDMetadataDocumentSupported + if s.config.RegistrationConfig.DynamicClientRegistrationEnabled { + registrationEndpoint = s.URL() + s.config.IssuerPath + "/register" + } + } + meta := &oauthex.AuthServerMeta{ + Issuer: s.URL() + s.config.IssuerPath, + AuthorizationEndpoint: s.URL() + s.config.IssuerPath + "/authorize", + TokenEndpoint: s.URL() + s.config.IssuerPath + "/token", + RegistrationEndpoint: registrationEndpoint, + ResponseTypesSupported: []string{"code"}, + CodeChallengeMethodsSupported: []string{"S256"}, + ClientIDMetadataDocumentSupported: cimdSupported, + TokenEndpointAuthMethodsSupported: []string{"client_secret_post", "client_secret_basic"}, + } + // Set CORS headers for cross-origin client discovery. + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + // Handle CORS preflight requests + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Only GET allowed for metadata retrieval + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, "Failed to encode metadata", http.StatusInternalServerError) + return + } +} + +func (s *FakeAuthorizationServer) handleRegister(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var metadata oauthex.ClientRegistrationMetadata + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read request body", http.StatusBadRequest) + return + } + if err := internaljson.Unmarshal(body, &metadata); err != nil { + http.Error(w, "failed to parse request", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + clientID := rand.Text() + ci := ClientInfo{ + Secret: rand.Text(), + RedirectURIs: metadata.RedirectURIs, + } + s.clients[clientID] = ci + metadata.TokenEndpointAuthMethod = "client_secret_basic" + json.NewEncoder(w).Encode(&oauthex.ClientRegistrationResponse{ + ClientID: clientID, + ClientSecret: ci.Secret, + ClientRegistrationMetadata: metadata, + }) +} + +func (s *FakeAuthorizationServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + clientID := r.URL.Query().Get("client_id") + clientInfo, ok := s.clients[clientID] + if !ok { + http.Error(w, "unknown client_id", http.StatusBadRequest) + return + } + + redirectURI := r.URL.Query().Get("redirect_uri") + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + if !slices.Contains(clientInfo.RedirectURIs, redirectURI) { + http.Error(w, "invalid redirect_uri", http.StatusBadRequest) + return + } + codeChallenge := r.URL.Query().Get("code_challenge") + if codeChallenge == "" { + http.Error(w, "missing code_challenge", http.StatusBadRequest) + return + } + code := rand.Text() + s.codes[code] = codeInfo{ + CodeChallenge: codeChallenge, + } + + state := r.URL.Query().Get("state") + + redirectURL := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, state) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := s.authenticateClient(r); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + if r.Form.Get("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + code := r.Form.Get("code") + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + codeInfo, ok := s.codes[code] + if !ok { + http.Error(w, "unknown authorization code", http.StatusBadRequest) + return + } + verifier := r.Form.Get("code_verifier") + if verifier == "" { + http.Error(w, "missing code_verifier", http.StatusBadRequest) + return + } + sha := sha256.Sum256([]byte(verifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(sha[:]) + if expectedChallenge != codeInfo.CodeChallenge { + http.Error(w, "PKCE verification failed", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + }) +} + +func (s *FakeAuthorizationServer) authenticateClient(r *http.Request) error { + clientID, clientSecret, ok := r.BasicAuth() + if !ok { + clientID = r.Form.Get("client_id") + clientSecret = r.Form.Get("client_secret") + } + + clientInfo, ok := s.clients[clientID] + if !ok || clientInfo.Secret != clientSecret { + return errors.New("client not found") + } + return nil +} diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index fce0fa44..d419b022 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -33,10 +33,15 @@ contains feature documentation, mapping the MCP spec to the packages above. The following table shows which versions of the Go SDK support which versions of the MCP specification: -| SDK Version | Latest MCP Spec | All Supported MCP Specs | -|-----------------|-------------------|------------------------------------------------| -| v1.2.0+ | 2025-06-18 | 2025-11-25, 2025-06-18, 2025-03-26, 2024-11-05 | -| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | +| SDK Version | Latest MCP Spec | All Supported MCP Specs | +|-----------------|-------------------|----------------------------------------------------| +| v1.4.0+ | 2025-11-25\* | 2025-11-25\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.2.0 - v1.3.1 | 2025-11-25\*\* | 2025-11-25\*\*, 2025-06-18, 2025-03-26, 2024-11-05 | +| v1.0.0 - v1.1.0 | 2025-06-18 | 2025-06-18, 2025-03-26, 2024-11-05 | + +\* Client side OAuth has experimental support. + +\*\* Partial support for 2025-11-25 (client side OAuth and Sampling with tools not available). New releases of the SDK target only supported versions of Go. See https://go.dev/doc/devel/release#policy for more information. diff --git a/mcp/streamable.go b/mcp/streamable.go index d78697de..0b11eff0 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -26,7 +26,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/modelcontextprotocol/go-sdk/auth" @@ -1456,6 +1455,9 @@ type StreamableClientTransport struct { // - You want to avoid maintaining a persistent connection DisableStandaloneSSE bool + // OAuthHandler is an optional field that, if provided, will be used to authorize the requests. + OAuthHandler auth.OAuthHandler + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1531,6 +1533,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er cancel: cancel, failed: make(chan struct{}), disableStandaloneSSE: t.DisableStandaloneSSE, + oauthHandler: t.OAuthHandler, } return conn, nil } @@ -1549,6 +1552,9 @@ type streamableClientConn struct { // for receiving server-to-client notifications when no request is in flight. disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE] + // oauthHandler is the OAuth handler for the connection. + oauthHandler auth.OAuthHandler // from [StreamableClientTransport.OAuthHandler] + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error @@ -1718,20 +1724,46 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return fmt.Errorf("%s: %v", requestSummary, err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + doRequest := func() (*http.Request, *http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.url, bytes.NewReader(data)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + if err := c.setMCPHeaders(req); err != nil { + // Failure to set headers means that the request was not sent. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + return nil, nil, fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + resp, err := c.client.Do(req) + if err != nil { + // Any error from client.Do means the request didn't reach the server. + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + err = fmt.Errorf("%s: %w: %v", requestSummary, jsonrpc2.ErrRejected, err) + } + return req, resp, err + } + + req, resp, err := doRequest() if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - c.setMCPHeaders(req) - resp, err := c.client.Do(req) - if err != nil { - // Any error from client.Do means the request didn't reach the server. - // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr - // and permanently break the connection. - return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err) + if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil { + if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr + // and permanently break the connection. + // Wrap the authorization error as well for client inspection. + return fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) + } + // Retry the request after successful authorization. + _, resp, err = doRequest() + if err != nil { + return err + } } if err := c.checkResponse(requestSummary, resp); err != nil { @@ -1799,23 +1831,32 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } -// testAuth controls whether a fake Authorization header is added to outgoing requests. -// TODO: replace with a better mechanism when client-side auth is in place. -var testAuth atomic.Bool - -func (c *streamableClientConn) setMCPHeaders(req *http.Request) { +func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { c.mu.Lock() defer c.mu.Unlock() + if c.oauthHandler != nil { + ts, err := c.oauthHandler.TokenSource(c.ctx) + if err != nil { + return err + } + if ts != nil { + token, err := ts.Token() + if err != nil { + return err + } + if token != nil { + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + } + } + } if c.initializedResult != nil { req.Header.Set(protocolVersionHeader, c.initializedResult.ProtocolVersion) } if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } - if testAuth.Load() { - req.Header.Set("Authorization", "Bearer foo") - } + return nil } func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) { @@ -2068,7 +2109,9 @@ func (c *streamableClientConn) connectSSE(ctx context.Context, lastEventID strin if err != nil { return nil, err } - c.setMCPHeaders(req) + if err := c.setMCPHeaders(req); err != nil { + return nil, err + } if lastEventID != "" { req.Header.Set(lastEventIDHeader, lastEventID) } @@ -2099,8 +2142,9 @@ func (c *streamableClientConn) Close() error { if err != nil { c.closeErr = err } else { - c.setMCPHeaders(req) - if _, err := c.client.Do(req); err != nil { + if err := c.setMCPHeaders(req); err != nil { + c.closeErr = err + } else if _, err := c.client.Do(req); err != nil { c.closeErr = err } } diff --git a/mcp/streamable_client_auth_test.go b/mcp/streamable_client_auth_test.go new file mode 100644 index 00000000..a1211e48 --- /dev/null +++ b/mcp/streamable_client_auth_test.go @@ -0,0 +1,179 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package mcp + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/auth" + "golang.org/x/oauth2" +) + +type mockOAuthHandler struct { + // Embed to satisfy the interface. + auth.AuthorizationCodeHandler + + token *oauth2.Token + authorizeErr error + authorizeCalled bool +} + +func (h *mockOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + if h.token == nil { + return nil, nil + } + return oauth2.StaticTokenSource(h.token), nil +} + +func (h *mockOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + h.authorizeCalled = true + return h.authorizeErr +} + +func TestStreamableClientOAuth_AuthorizationHeader(t *testing.T) { + ctx := context.Background() + token := &oauth2.Token{AccessToken: "test-token"} + oauthHandler := &mockOAuthHandler{token: token} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + {"GET", "123", "", ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + }, + {"DELETE", "123", "", ""}: {}, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + session.Close() +} + +func TestStreamableClientOAuth_401(t *testing.T) { + ctx := context.Background() + oauthHandler := &mockOAuthHandler{token: nil} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + // Accept any token. + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + _, err := client.Connect(ctx, transport, nil) + if err == nil || !strings.Contains(err.Error(), "Unauthorized") { + t.Fatalf("client.Connect() error does not contain 'Unauthorized': %v", err) + } + + if !oauthHandler.authorizeCalled { + t.Errorf("expected Authorize to be called") + } +} + +func TestTokenInfo(t *testing.T) { + ctx := context.Background() + + // Create a server with a tool that returns TokenInfo. + tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil + } + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) + + streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + if token != "test-token" { + return nil, auth.ErrInvalidToken + } + return &auth.TokenInfo{ + Scopes: []string{"scope"}, + // Expiration is far, far in the future. + Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), + }, nil + } + handler := auth.RequireBearerToken(verifier, nil)(streamHandler) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: &mockOAuthHandler{token: &oauth2.Token{AccessToken: "test-token"}}, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) + if err != nil { + t.Fatal(err) + } + if len(res.Content) == 0 { + t.Fatal("missing content") + } + tc, ok := res.Content[0].(*TextContent) + if !ok { + t.Fatal("not TextContent") + } + if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { + t.Errorf("got %q, want %q", g, w) + } +} diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index e3adfe98..d189ca41 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -17,7 +17,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2cbe4002..22d0d1c6 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1666,55 +1666,6 @@ func textContent(t *testing.T, res *CallToolResult) string { return text.Text } -func TestTokenInfo(t *testing.T) { - oldAuth := testAuth.Load() - defer testAuth.Store(oldAuth) - testAuth.Store(true) - ctx := context.Background() - - // Create a server with a tool that returns TokenInfo. - tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { - return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil - } - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "tokenInfo", Description: "return token info"}, tokenInfo) - - streamHandler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) - verifier := func(context.Context, string, *http.Request) (*auth.TokenInfo, error) { - return &auth.TokenInfo{ - Scopes: []string{"scope"}, - // Expiration is far, far in the future. - Expiration: time.Date(5000, 1, 2, 3, 4, 5, 0, time.UTC), - }, nil - } - handler := auth.RequireBearerToken(verifier, nil)(streamHandler) - httpServer := httptest.NewServer(mustNotPanic(t, handler)) - defer httpServer.Close() - - transport := &StreamableClientTransport{Endpoint: httpServer.URL} - client := NewClient(testImpl, nil) - session, err := client.Connect(ctx, transport, nil) - if err != nil { - t.Fatalf("client.Connect() failed: %v", err) - } - defer session.Close() - - res, err := session.CallTool(ctx, &CallToolParams{Name: "tokenInfo"}) - if err != nil { - t.Fatal(err) - } - if len(res.Content) == 0 { - t.Fatal("missing content") - } - tc, ok := res.Content[0].(*TextContent) - if !ok { - t.Fatal("not TextContent") - } - if g, w := tc.Text, "&{[scope] 5000-01-02 03:04:05 +0000 UTC map[]}"; g != w { - t.Errorf("got %q, want %q", g, w) - } -} - func TestSessionHijackingPrevention(t *testing.T) { // This test verifies that sessions bound to a user ID cannot be accessed // by a different user (session hijacking prevention). diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 9aa0c8d7..b05d80b6 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -14,6 +14,9 @@ import ( "errors" "fmt" "net/http" + "net/url" + + "github.com/modelcontextprotocol/go-sdk/internal/util" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -28,8 +31,6 @@ import ( // // [RFC 8414]: https://tools.ietf.org/html/rfc8414) type AuthServerMeta struct { - // GENERATED BY GEMINI 2.5. - // Issuer is the REQUIRED URL identifying the authorization server. Issuer string `json:"issuer"` @@ -113,51 +114,61 @@ type AuthServerMeta struct { // CodeChallengeMethodsSupported is a RECOMMENDED JSON array of strings containing a list of // PKCE code challenge methods supported by this authorization server. CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` -} -var wellKnownPaths = []string{ - "/.well-known/oauth-authorization-server", - "/.well-known/openid-configuration", + // ClientIDMetadataDocumentSupported is a boolean indicating whether the authorization server + // supports client ID metadata documents. + ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` } // GetAuthServerMeta issues a GET request to retrieve authorization server metadata -// from an OAuth authorization server with the given issuerURL. +// from an OAuth authorization server with the given metadataURL. // // It follows [RFC 8414]: -// - The well-known paths specified there are inserted into the URL's path, one at time. -// The first to succeed is used. -// - The Issuer field is checked against issuerURL. +// - The metadataURL must use HTTPS or be a local address. +// - The Issuer field is checked against metadataURL.Issuer. +// +// It also verifies that the authorization server supports PKCE and that the URLs +// in the metadata don't use dangerous schemes. +// +// It returns an error if the request fails with a non-4xx status code or the fetched +// metadata doesn't pass security validations. +// It returns nil if the request fails with a 4xx status code. // // [RFC 8414]: https://tools.ietf.org/html/rfc8414 -func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (*AuthServerMeta, error) { - var errs []error - for _, p := range wellKnownPaths { - u, err := prependToPath(issuerURL, p) - if err != nil { - // issuerURL is bad; no point in continuing. - return nil, err - } - asm, err := getJSON[AuthServerMeta](ctx, c, u, 1<<20) - if err == nil { - if asm.Issuer != issuerURL { // section 3.3 - // Security violation; don't keep trying. - return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuerURL) - } - - if len(asm.CodeChallengeMethodsSupported) == 0 { - return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuerURL) +func GetAuthServerMeta(ctx context.Context, metadataURL, issuer string, c *http.Client) (*AuthServerMeta, error) { + u, err := url.Parse(metadataURL) + if err != nil { + return nil, err + } + // Only allow HTTP for local addresses (testing or development purposes). + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) + } + asm, err := getJSON[AuthServerMeta](ctx, c, metadataURL, 1<<20) + if err != nil { + var httpErr *httpStatusError + if errors.As(err, &httpErr) { + if 400 <= httpErr.StatusCode && httpErr.StatusCode < 500 { + return nil, nil } + } + return nil, fmt.Errorf("%v", err) // Do not expose error types. + } + if asm.Issuer != issuer { + // Validate the Issuer field (see RFC 8414, section 3.3). + return nil, fmt.Errorf("metadata issuer %q does not match issuer URL %q", asm.Issuer, issuer) + } - // Validate endpoint URLs to prevent XSS attacks (see #526). - if err := validateAuthServerMetaURLs(asm); err != nil { - return nil, err - } + if len(asm.CodeChallengeMethodsSupported) == 0 { + return nil, fmt.Errorf("authorization server at %s does not implement PKCE", issuer) + } - return asm, nil - } - errs = append(errs, err) + // Validate endpoint URLs to prevent XSS attacks (see #526). + if err := validateAuthServerMetaURLs(asm); err != nil { + return nil, err } - return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) + + return asm, nil } // validateAuthServerMetaURLs validates all URL fields in AuthServerMeta diff --git a/oauthex/auth_meta_test.go b/oauthex/auth_meta_test.go index 1e608824..8b67e59b 100644 --- a/oauthex/auth_meta_test.go +++ b/oauthex/auth_meta_test.go @@ -86,6 +86,7 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { // The fake server sets issuer to https://localhost:, so compute that issuer. u, _ := url.Parse(ts.URL) issuer := "https://localhost:" + u.Port() + metadataURL := issuer + "/.well-known/oauth-authorization-server" // The fake server presents a cert for example.com; set ServerName accordingly. httpClient := ts.Client() @@ -95,7 +96,7 @@ func TestGetAuthServerMetaPKCESupport(t *testing.T) { httpClient.Transport = clone } - meta, err := GetAuthServerMeta(ctx, issuer, httpClient) + meta, err := GetAuthServerMeta(ctx, metadataURL, issuer, httpClient) if tt.wantError != "" { if err == nil { t.Fatal("wanted error but got none") diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index cdda695b..836a4201 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -19,21 +19,12 @@ import ( "strings" ) -// prependToPath prepends pre to the path of urlStr. -// When pre is the well-known path, this is the algorithm specified in both RFC 9728 -// section 3.1 and RFC 8414 section 3.1. -func prependToPath(urlStr, pre string) (string, error) { - u, err := url.Parse(urlStr) - if err != nil { - return "", err - } - p := "/" + strings.Trim(pre, "/") - if u.Path != "" { - p += "/" - } +type httpStatusError struct { + StatusCode int +} - u.Path = p + strings.TrimLeft(u.Path, "/") - return u.String(), nil +func (e *httpStatusError) Error() string { + return fmt.Sprintf("bad status %d", e.StatusCode) } // getJSON retrieves JSON and unmarshals JSON from the URL, as specified in both @@ -53,11 +44,9 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 } defer res.Body.Close() - // Specs require a 200. if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("bad status %s", res.Status) + return nil, &httpStatusError{StatusCode: res.StatusCode} } - // Specs require application/json. ct := res.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(ct) if err != nil || mediaType != "application/json" { diff --git a/oauthex/oauth2_test.go b/oauthex/oauth2_test.go index 08d2d314..36f732e8 100644 --- a/oauthex/oauth2_test.go +++ b/oauthex/oauth2_test.go @@ -82,13 +82,13 @@ func TestParseSingleChallenge(t *testing.T) { tests := []struct { name string input string - want challenge + want Challenge wantErr bool }{ { name: "scheme only", input: "Basic", - want: challenge{ + want: Challenge{ Scheme: "basic", }, wantErr: false, @@ -96,7 +96,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with one quoted param", input: `Bearer realm="example.com"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, @@ -105,7 +105,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with one unquoted param", input: `Bearer realm=example.com`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, @@ -114,7 +114,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with multiple params", input: `Bearer realm="example", error="invalid_token", error_description="The token expired"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{ "realm": "example", @@ -127,7 +127,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "scheme with multiple unquoted params", input: `Bearer realm=example, error=invalid_token, error_description=The token expired`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{ "realm": "example", @@ -140,7 +140,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "case-insensitive scheme and keys", input: `BEARER ReAlM="example"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example"}, }, @@ -149,7 +149,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "param with escaped quote", input: `Bearer realm="example \"foo\" bar"`, - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": `example "foo" bar`}, }, @@ -158,7 +158,7 @@ func TestParseSingleChallenge(t *testing.T) { { name: "param without quotes (token)", input: "Bearer realm=example.com", - want: challenge{ + want: Challenge{ Scheme: "bearer", Params: map[string]string{"realm": "example.com"}, }, diff --git a/oauthex/oauthex.go b/oauthex/oauthex.go index 34ed55b5..151da7e5 100644 --- a/oauthex/oauthex.go +++ b/oauthex/oauthex.go @@ -4,89 +4,3 @@ // Package oauthex implements extensions to OAuth2. package oauthex - -// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, -// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. -// -// The following features are not supported: -// - additional keys (§2, last sentence) -// - human-readable metadata (§2.1) -// - signed metadata (§2.2) -type ProtectedResourceMetadata struct { - // GENERATED BY GEMINI 2.5. - - // Resource (resource) is the protected resource's resource identifier. - // Required. - Resource string `json:"resource"` - - // AuthorizationServers (authorization_servers) is an optional slice containing a list of - // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be - // used with this protected resource. - AuthorizationServers []string `json:"authorization_servers,omitempty"` - - // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set - // document. This contains public keys belonging to the protected resource, such as - // signing key(s) that the resource server uses to sign resource responses. - JWKSURI string `json:"jwks_uri,omitempty"` - - // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope - // values (as defined in RFC 6749) used in authorization requests to request access - // to this protected resource. - ScopesSupported []string `json:"scopes_supported,omitempty"` - - // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing - // a list of the supported methods of sending an OAuth 2.0 bearer token to the - // protected resource. Defined values are "header", "body", and "query". - BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` - - // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms (alg values) supported by the protected - // resource for signing resource responses. - ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` - - // ResourceName (resource_name) is a human-readable name of the protected resource - // intended for display to the end user. It is RECOMMENDED that this field be included. - // This value may be internationalized. - ResourceName string `json:"resource_name,omitempty"` - - // ResourceDocumentation (resource_documentation) is an optional URL of a page containing - // human-readable information for developers using the protected resource. - // This value may be internationalized. - ResourceDocumentation string `json:"resource_documentation,omitempty"` - - // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing - // human-readable policy information on how a client can use the data provided. - // This value may be internationalized. - ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` - - // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected - // resource's human-readable terms of service. This value may be internationalized. - ResourceTOSURI string `json:"resource_tos_uri,omitempty"` - - // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an - // optional boolean indicating support for mutual-TLS client certificate-bound - // access tokens (RFC 8705). Defaults to false if omitted. - TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` - - // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional - // slice of 'type' values supported by the resource server for the - // 'authorization_details' parameter (RFC 9396). - AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` - - // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional - // slice of JWS signing algorithms supported by the resource server for validating - // DPoP proof JWTs (RFC 9449). - DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` - - // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean - // specifying whether the protected resource always requires the use of DPoP-bound - // access tokens (RFC 9449). Defaults to false if omitted. - DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` - - // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters - // about the protected resource as claims. If present, these values take precedence - // over values conveyed in plain JSON. - // TODO:implement. - // Note that §2.2 says it's okay to ignore this. - // SignedMetadata string `json:"signed_metadata,omitempty"` -} diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index bb61f797..8b911cad 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -38,6 +38,8 @@ const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resour // // It then retrieves the metadata at that location using the given client (or the // default client if nil) and validates its resource field against resourceID. +// +// Deprecated: Use [GetProtectedResourceMetadata] instead. This function will be removed in v1.5.0. func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID) @@ -47,7 +49,7 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, } // Insert well-known URI into URL. u.Path = path.Join(defaultProtectedResourceMetadataURI, u.Path) - return getPRM(ctx, u.String(), c, resourceID) + return GetProtectedResourceMetadata(ctx, u.String(), resourceID, c) } // GetProtectedResourceMetadataFromHeader retrieves protected resource metadata @@ -57,8 +59,9 @@ func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, // Per RFC 9728 section 3.3, it validates that the resource field of the resulting metadata // matches the serverURL (the URL that the client used to make the original request to the resource server). // If there is no metadata URL in the header, it returns nil, nil. +// +// Deprecated: Use [GetProtectedResourceMetadata] instead. This function will be removed in v1.5.0. func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) { - defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader") headers := header[http.CanonicalHeaderKey("WWW-Authenticate")] if len(headers) == 0 { return nil, nil @@ -67,26 +70,49 @@ func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL strin if err != nil { return nil, err } - metadataURL := ResourceMetadataURL(cs) + metadataURL := resourceMetadataURL(cs) if metadataURL == "" { return nil, nil } - return getPRM(ctx, metadataURL, c, serverURL) + return GetProtectedResourceMetadata(ctx, metadataURL, serverURL, c) } -// getPRM makes a GET request to the given URL, and validates the response. -// As part of the validation, it compares the returned resource field to wantResource. -func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) { - if !strings.HasPrefix(strings.ToUpper(purl), "HTTPS://") { - return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl) +// resourceMetadataURL returns a resource metadata URL from the given "WWW-Authenticate" header challenges, +// or the empty string if there is none. +func resourceMetadataURL(cs []Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// GetProtectedResourceMetadataFromID issues a GET request to retrieve protected resource +// metadata from a resource server. +// The metadataURL is typically a URL with a host:port and possibly a path. +// The resourceURL is the resource URI the metadataURL is for. +// The following checks are performed: +// - The metadataURL must use HTTPS or be a local address. +// - The resource field of the resulting metadata must match the resourceURL. +// - The authorization_servers field of the resulting metadata is checked for dangerous URL schemes. +func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL string, c *http.Client) (_ *ProtectedResourceMetadata, err error) { + defer util.Wrapf(&err, "GetProtectedResourceMetadata(%q)", metadataURL) + u, err := url.Parse(metadataURL) + if err != nil { + return nil, err + } + // Only allow HTTP for local addresses (testing or development purposes). + if !util.IsLoopback(u.Host) && u.Scheme != "https" { + return nil, fmt.Errorf("metadataURL %q does not use HTTPS", metadataURL) } - prm, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20) + prm, err := getJSON[ProtectedResourceMetadata](ctx, c, metadataURL, 1<<20) if err != nil { return nil, err } // Validate the Resource field (see RFC 9728, section 3.3). - if prm.Resource != wantResource { - return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, wantResource) + if prm.Resource != resourceURL { + return nil, fmt.Errorf("got metadata resource %q, want %q", prm.Resource, resourceURL) } // Validate the authorization server URLs to prevent XSS attacks (see #526). for _, u := range prm.AuthorizationServers { @@ -97,37 +123,12 @@ func getPRM(ctx context.Context, purl string, c *http.Client, wantResource strin return prm, nil } -// challenge represents a single authentication challenge from a WWW-Authenticate header. -// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. -type challenge struct { - // GENERATED BY GEMINI 2.5. - // - // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). - // It is case-insensitive. A parsed value will always be lower-case. - Scheme string - // Params is a map of authentication parameters. - // Keys are case-insensitive. Parsed keys are always lower-case. - Params map[string]string -} - -// ResourceMetadataURL returns a resource metadata URL from the given challenges, -// or the empty string if there is none. -func ResourceMetadataURL(cs []challenge) string { - for _, c := range cs { - if u := c.Params["resource_metadata"]; u != "" { - return u - } - } - return "" -} - // ParseWWWAuthenticate parses a WWW-Authenticate header string. // The header format is defined in RFC 9110, Section 11.6.1, and can contain // one or more challenges, separated by commas. // It returns a slice of challenges or an error if one of the headers is malformed. -func ParseWWWAuthenticate(headers []string) ([]challenge, error) { - // GENERATED BY GEMINI 2.5 (human-tweaked) - var challenges []challenge +func ParseWWWAuthenticate(headers []string) ([]Challenge, error) { + var challenges []Challenge for _, h := range headers { challengeStrings, err := splitChallenges(h) if err != nil { @@ -151,7 +152,6 @@ func ParseWWWAuthenticate(headers []string) ([]challenge, error) { // It correctly handles commas within quoted strings and distinguishes between // commas separating auth-params and commas separating challenges. func splitChallenges(header string) ([]string, error) { - // GENERATED BY GEMINI 2.5. var challenges []string inQuotes := false start := 0 @@ -195,15 +195,14 @@ func splitChallenges(header string) ([]string, error) { // parseSingleChallenge parses a string containing exactly one challenge. // challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ] -func parseSingleChallenge(s string) (challenge, error) { - // GENERATED BY GEMINI 2.5, human-tweaked. +func parseSingleChallenge(s string) (Challenge, error) { s = strings.TrimSpace(s) if s == "" { - return challenge{}, errors.New("empty challenge string") + return Challenge{}, errors.New("empty challenge string") } scheme, paramsStr, found := strings.Cut(s, " ") - c := challenge{Scheme: strings.ToLower(scheme)} + c := Challenge{Scheme: strings.ToLower(scheme)} if !found { return c, nil } @@ -215,7 +214,7 @@ func parseSingleChallenge(s string) (challenge, error) { // Find the end of the parameter key. keyEnd := strings.Index(paramsStr, "=") if keyEnd <= 0 { - return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) + return Challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr) } key := strings.TrimSpace(paramsStr[:keyEnd]) @@ -243,7 +242,7 @@ func parseSingleChallenge(s string) (challenge, error) { // A quoted string must be terminated. if i == len(paramsStr) { - return challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") + return Challenge{}, fmt.Errorf("unterminated quoted string in auth parameter") } value = valBuilder.String() @@ -261,7 +260,7 @@ func parseSingleChallenge(s string) (challenge, error) { } } if value == "" { - return challenge{}, fmt.Errorf("no value for auth param %q", key) + return Challenge{}, fmt.Errorf("no value for auth param %q", key) } // Per RFC 9110, parameter keys are case-insensitive. @@ -272,10 +271,10 @@ func parseSingleChallenge(s string) (challenge, error) { paramsStr = strings.TrimSpace(paramsStr[1:]) } else if paramsStr != "" { // If there's content but it's not a new parameter, the format is wrong. - return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) + return Challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr) } } // Per RFC 9110, the scheme is case-insensitive. - return challenge{Scheme: strings.ToLower(scheme), Params: params}, nil + return Challenge{Scheme: strings.ToLower(scheme), Params: params}, nil } diff --git a/oauthex/resource_meta_public.go b/oauthex/resource_meta_public.go new file mode 100644 index 00000000..3bf7d9ac --- /dev/null +++ b/oauthex/resource_meta_public.go @@ -0,0 +1,105 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Protected Resource Metadata. +// See https://www.rfc-editor.org/rfc/rfc9728.html. + +// This is a temporary file to expose the required objects to the main package. + +package oauthex + +// ProtectedResourceMetadata is the metadata for an OAuth 2.0 protected resource, +// as defined in section 2 of https://www.rfc-editor.org/rfc/rfc9728.html. +// +// The following features are not supported: +// - additional keys (§2, last sentence) +// - human-readable metadata (§2.1) +// - signed metadata (§2.2) +type ProtectedResourceMetadata struct { + // Resource (resource) is the protected resource's resource identifier. + // Required. + Resource string `json:"resource"` + + // AuthorizationServers (authorization_servers) is an optional slice containing a list of + // OAuth authorization server issuer identifiers (as defined in RFC 8414) that can be + // used with this protected resource. + AuthorizationServers []string `json:"authorization_servers,omitempty"` + + // JWKSURI (jwks_uri) is an optional URL of the protected resource's JSON Web Key (JWK) Set + // document. This contains public keys belonging to the protected resource, such as + // signing key(s) that the resource server uses to sign resource responses. + JWKSURI string `json:"jwks_uri,omitempty"` + + // ScopesSupported (scopes_supported) is a recommended slice containing a list of scope + // values (as defined in RFC 6749) used in authorization requests to request access + // to this protected resource. + ScopesSupported []string `json:"scopes_supported,omitempty"` + + // BearerMethodsSupported (bearer_methods_supported) is an optional slice containing + // a list of the supported methods of sending an OAuth 2.0 bearer token to the + // protected resource. Defined values are "header", "body", and "query". + BearerMethodsSupported []string `json:"bearer_methods_supported,omitempty"` + + // ResourceSigningAlgValuesSupported (resource_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms (alg values) supported by the protected + // resource for signing resource responses. + ResourceSigningAlgValuesSupported []string `json:"resource_signing_alg_values_supported,omitempty"` + + // ResourceName (resource_name) is a human-readable name of the protected resource + // intended for display to the end user. It is RECOMMENDED that this field be included. + // This value may be internationalized. + ResourceName string `json:"resource_name,omitempty"` + + // ResourceDocumentation (resource_documentation) is an optional URL of a page containing + // human-readable information for developers using the protected resource. + // This value may be internationalized. + ResourceDocumentation string `json:"resource_documentation,omitempty"` + + // ResourcePolicyURI (resource_policy_uri) is an optional URL of a page containing + // human-readable policy information on how a client can use the data provided. + // This value may be internationalized. + ResourcePolicyURI string `json:"resource_policy_uri,omitempty"` + + // ResourceTOSURI (resource_tos_uri) is an optional URL of a page containing the protected + // resource's human-readable terms of service. This value may be internationalized. + ResourceTOSURI string `json:"resource_tos_uri,omitempty"` + + // TLSClientCertificateBoundAccessTokens (tls_client_certificate_bound_access_tokens) is an + // optional boolean indicating support for mutual-TLS client certificate-bound + // access tokens (RFC 8705). Defaults to false if omitted. + TLSClientCertificateBoundAccessTokens bool `json:"tls_client_certificate_bound_access_tokens,omitempty"` + + // AuthorizationDetailsTypesSupported (authorization_details_types_supported) is an optional + // slice of 'type' values supported by the resource server for the + // 'authorization_details' parameter (RFC 9396). + AuthorizationDetailsTypesSupported []string `json:"authorization_details_types_supported,omitempty"` + + // DPOPSigningAlgValuesSupported (dpop_signing_alg_values_supported) is an optional + // slice of JWS signing algorithms supported by the resource server for validating + // DPoP proof JWTs (RFC 9449). + DPOPSigningAlgValuesSupported []string `json:"dpop_signing_alg_values_supported,omitempty"` + + // DPOPBoundAccessTokensRequired (dpop_bound_access_tokens_required) is an optional boolean + // specifying whether the protected resource always requires the use of DPoP-bound + // access tokens (RFC 9449). Defaults to false if omitted. + DPOPBoundAccessTokensRequired bool `json:"dpop_bound_access_tokens_required,omitempty"` + + // SignedMetadata (signed_metadata) is an optional JWT containing metadata parameters + // about the protected resource as claims. If present, these values take precedence + // over values conveyed in plain JSON. + // TODO:implement. + // Note that §2.2 says it's okay to ignore this. + // SignedMetadata string `json:"signed_metadata,omitempty"` +} + +// Challenge represents a single authentication challenge from a WWW-Authenticate header. +// As per RFC 9110, Section 11.6.1, a challenge consists of a scheme and optional parameters. +type Challenge struct { + // Scheme is the authentication scheme (e.g., "Bearer", "Basic"). + // It is case-insensitive. A parsed value will always be lower-case. + Scheme string + // Params is a map of authentication parameters. + // Keys are case-insensitive. Parsed keys are always lower-case. + Params map[string]string +} diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go index 531a1f9c..83eeb5e1 100644 --- a/oauthex/url_scheme_test.go +++ b/oauthex/url_scheme_test.go @@ -226,7 +226,9 @@ func TestGetAuthServerMetaRejectsDangerousURLs(t *testing.T) { defer server.Close() ctx := context.Background() - _, err := GetAuthServerMeta(ctx, server.URL, server.Client()) + issuer := server.URL + metadataURL := issuer + _, err := GetAuthServerMeta(ctx, metadataURL, issuer, server.Client()) if err == nil { t.Fatal("GetAuthServerMeta(): got nil error, want error") } diff --git a/scripts/client-conformance.sh b/scripts/client-conformance.sh index c093c75f..17528450 100755 --- a/scripts/client-conformance.sh +++ b/scripts/client-conformance.sh @@ -10,6 +10,7 @@ set -e RESULT_DIR="" WORKDIR="" CONFORMANCE_REPO="" +SUITE="core" FINAL_EXIT_CODE=0 usage() { @@ -21,9 +22,11 @@ usage() { echo " --result_dir Save results to the specified directory" echo " --conformance_repo Run conformance tests from a local checkout" echo " instead of using the latest npm release" + echo " --suite Which suite to run (default: core)" echo " --help Show this help message" } + # Parse arguments. while [[ $# -gt 0 ]]; do case $1 in @@ -35,6 +38,10 @@ while [[ $# -gt 0 ]]; do CONFORMANCE_REPO="$2" shift 2 ;; + --suite) + SUITE="$2" + shift 2 + ;; --help) usage exit 0 @@ -56,7 +63,7 @@ else fi # Build the conformance server. -go build -o "$WORKDIR/conformance-client" ./conformance/everything-client +go build -tags mcp_go_client_oauth -o "$WORKDIR/conformance-client" ./conformance/everything-client # Run conformance tests from the work directory to avoid writing results to the repo. echo "Running conformance tests..." @@ -65,13 +72,13 @@ if [ -n "$CONFORMANCE_REPO" ]; then (cd "$WORKDIR" && \ npm --prefix "$CONFORMANCE_REPO" run start -- \ client --command "$WORKDIR/conformance-client" \ - --suite core \ + --suite "$SUITE" \ ${RESULT_DIR:+--output-dir "$RESULT_DIR"}) || FINAL_EXIT_CODE=$? else (cd "$WORKDIR" && \ npx @modelcontextprotocol/conformance@latest \ client --command "$WORKDIR/conformance-client" \ - --suite core \ + --suite "$SUITE" \ ${RESULT_DIR:+--output-dir "$RESULT_DIR"}) || FINAL_EXIT_CODE=$? fi