diff --git a/pkg/config/config.go b/pkg/config/config.go index af3f0ae..723d17c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -6,17 +6,23 @@ import ( "github.com/BurntSushi/toml" ) +// StaticConfig is the configuration for the server. +// It allows to configure server specific settings and tools to be enabled or disabled. type StaticConfig struct { DeniedResources []GroupVersionKind `toml:"denied_resources"` - LogLevel int `toml:"log_level,omitempty"` - SSEPort int `toml:"sse_port,omitempty"` - HTTPPort int `toml:"http_port,omitempty"` - SSEBaseURL string `toml:"sse_base_url,omitempty"` - KubeConfig string `toml:"kubeconfig,omitempty"` - ListOutput string `toml:"list_output,omitempty"` - ReadOnly bool `toml:"read_only,omitempty"` - DisableDestructive bool `toml:"disable_destructive,omitempty"` + LogLevel int `toml:"log_level,omitempty"` + SSEPort int `toml:"sse_port,omitempty"` + HTTPPort int `toml:"http_port,omitempty"` + SSEBaseURL string `toml:"sse_base_url,omitempty"` + KubeConfig string `toml:"kubeconfig,omitempty"` + ListOutput string `toml:"list_output,omitempty"` + // When true, expose only tools annotated with readOnlyHint=true + ReadOnly bool `toml:"read_only,omitempty"` + // When true, disable tools annotated with destructiveHint=true + DisableDestructive bool `toml:"disable_destructive,omitempty"` + EnabledTools []string `toml:"enabled_tools,omitempty"` + DisabledTools []string `toml:"disabled_tools,omitempty"` } type GroupVersionKind struct { @@ -25,6 +31,7 @@ type GroupVersionKind struct { Kind string `toml:"kind,omitempty"` } +// ReadConfig reads the toml file and returns the StaticConfig. func ReadConfig(configPath string) (*StaticConfig, error) { configData, err := os.ReadFile(configPath) if err != nil { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 1aac25f..0a7be56 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -57,14 +57,13 @@ list_output = "yaml" read_only = true disable_destructive = false -[[denied_resources]] -group = "apps" -version = "v1" -kind = "Deployment" +denied_resources = [ + {group = "apps", version = "v1", kind = "Deployment"}, + {group = "rbac.authorization.k8s.io", version = "v1", kind = "Role"} +] -[[denied_resources]] -group = "rbac.authorization.k8s.io" -version = "v1" +enabled_tools = ["configuration_view", "events_list", "namespaces_list", "pods_list", "resources_list", "resources_get", "resources_create_or_update", "resources_delete"] +disabled_tools = ["pods_delete", "pods_top", "pods_log", "pods_run", "pods_exec"] `) config, err := ReadConfig(validConfigPath) @@ -109,6 +108,12 @@ version = "v1" if config.DisableDestructive { t.Fatalf("Unexpected disable destructive: %v", config.DisableDestructive) } + if len(config.EnabledTools) != 8 { + t.Fatalf("Unexpected enabled tools: %v", config.EnabledTools) + } + if len(config.DisabledTools) != 5 { + t.Fatalf("Unexpected disabled tools: %v", config.DisabledTools) + } }) } diff --git a/pkg/kubernetes-mcp-server/cmd/root.go b/pkg/kubernetes-mcp-server/cmd/root.go index 9d36fe3..cd4837f 100644 --- a/pkg/kubernetes-mcp-server/cmd/root.go +++ b/pkg/kubernetes-mcp-server/cmd/root.go @@ -186,12 +186,9 @@ func (m *MCPServerOptions) Run() error { return nil } mcpServer, err := mcp.NewServer(mcp.Configuration{ - Profile: profile, - ListOutput: listOutput, - ReadOnly: m.StaticConfig.ReadOnly, - DisableDestructive: m.StaticConfig.DisableDestructive, - Kubeconfig: m.StaticConfig.KubeConfig, - StaticConfig: m.StaticConfig, + Profile: profile, + ListOutput: listOutput, + StaticConfig: m.StaticConfig, }) if err != nil { return fmt.Errorf("Failed to initialize MCP server: %w\n", err) diff --git a/pkg/kubernetes-mcp-server/cmd/testdata/valid-config.toml b/pkg/kubernetes-mcp-server/cmd/testdata/valid-config.toml index bb8da16..8b46a1a 100644 --- a/pkg/kubernetes-mcp-server/cmd/testdata/valid-config.toml +++ b/pkg/kubernetes-mcp-server/cmd/testdata/valid-config.toml @@ -5,11 +5,11 @@ list_output = "yaml" read_only = true disable_destructive = true -[[denied_resources]] -group = "apps" -version = "v1" -kind = "Deployment" +denied_resources = [ + {group = "apps", version = "v1", kind = "Deployment"}, + {group = "rbac.authorization.k8s.io", version = "v1", kind = "Role"} +] + +enabled_tools = ["configuration_view", "events_list", "namespaces_list", "pods_list", "resources_list", "resources_get", "resources_create_or_update", "resources_delete"] +disabled_tools = ["pods_delete", "pods_top", "pods_log", "pods_run", "pods_exec"] -[[denied_resources]] -group = "rbac.authorization.k8s.io" -version = "v1" \ No newline at end of file diff --git a/pkg/mcp/common_test.go b/pkg/mcp/common_test.go index d21f8f9..458ff8f 100644 --- a/pkg/mcp/common_test.go +++ b/pkg/mcp/common_test.go @@ -4,6 +4,13 @@ import ( "context" "encoding/json" "fmt" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + "github.com/manusa/kubernetes-mcp-server/pkg/config" "github.com/manusa/kubernetes-mcp-server/pkg/output" "github.com/mark3labs/mcp-go/client" @@ -28,18 +35,12 @@ import ( "k8s.io/client-go/tools/clientcmd/api" toolswatch "k8s.io/client-go/tools/watch" "k8s.io/utils/ptr" - "net/http/httptest" - "os" - "path/filepath" - "runtime" "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/tools/setup-envtest/env" "sigs.k8s.io/controller-runtime/tools/setup-envtest/remote" "sigs.k8s.io/controller-runtime/tools/setup-envtest/store" "sigs.k8s.io/controller-runtime/tools/setup-envtest/versions" "sigs.k8s.io/controller-runtime/tools/setup-envtest/workflows" - "testing" - "time" ) // envTest has an expensive setup, so we only want to do it once per entire test run. @@ -97,20 +98,19 @@ func TestMain(m *testing.M) { } type mcpContext struct { - profile Profile - listOutput output.Output - readOnly bool - disableDestructive bool - staticConfig *config.StaticConfig - clientOptions []transport.ClientOption - before func(*mcpContext) - after func(*mcpContext) - ctx context.Context - tempDir string - cancel context.CancelFunc - mcpServer *Server - mcpHttpServer *httptest.Server - mcpClient *client.Client + profile Profile + listOutput output.Output + + staticConfig *config.StaticConfig + clientOptions []transport.ClientOption + before func(*mcpContext) + after func(*mcpContext) + ctx context.Context + tempDir string + cancel context.CancelFunc + mcpServer *Server + mcpHttpServer *httptest.Server + mcpClient *client.Client } func (c *mcpContext) beforeEach(t *testing.T) { @@ -125,17 +125,18 @@ func (c *mcpContext) beforeEach(t *testing.T) { c.listOutput = output.Yaml } if c.staticConfig == nil { - c.staticConfig = &config.StaticConfig{} + c.staticConfig = &config.StaticConfig{ + ReadOnly: false, + DisableDestructive: false, + } } if c.before != nil { c.before(c) } if c.mcpServer, err = NewServer(Configuration{ - Profile: c.profile, - ListOutput: c.listOutput, - ReadOnly: c.readOnly, - DisableDestructive: c.disableDestructive, - StaticConfig: c.staticConfig, + Profile: c.profile, + ListOutput: c.listOutput, + StaticConfig: c.staticConfig, }); err != nil { t.Fatal(err) return diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 8d0d950..1b4961c 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -2,13 +2,14 @@ package mcp import ( "context" - "github.com/manusa/kubernetes-mcp-server/pkg/config" "net/http" + "slices" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "k8s.io/utils/ptr" + "github.com/manusa/kubernetes-mcp-server/pkg/config" "github.com/manusa/kubernetes-mcp-server/pkg/kubernetes" "github.com/manusa/kubernetes-mcp-server/pkg/output" "github.com/manusa/kubernetes-mcp-server/pkg/version" @@ -17,15 +18,26 @@ import ( type Configuration struct { Profile Profile ListOutput output.Output - // When true, expose only tools annotated with readOnlyHint=true - ReadOnly bool - // When true, disable tools annotated with destructiveHint=true - DisableDestructive bool - Kubeconfig string StaticConfig *config.StaticConfig } +func (c *Configuration) isToolApplicable(tool server.ServerTool) bool { + if c.StaticConfig.ReadOnly && !ptr.Deref(tool.Tool.Annotations.ReadOnlyHint, false) { + return false + } + if c.StaticConfig.DisableDestructive && !ptr.Deref(tool.Tool.Annotations.ReadOnlyHint, false) && ptr.Deref(tool.Tool.Annotations.DestructiveHint, false) { + return false + } + if c.StaticConfig.EnabledTools != nil && !slices.Contains(c.StaticConfig.EnabledTools, tool.Tool.Name) { + return false + } + if c.StaticConfig.DisabledTools != nil && slices.Contains(c.StaticConfig.DisabledTools, tool.Tool.Name) { + return false + } + return true +} + type Server struct { configuration *Configuration server *server.MCPServer @@ -53,17 +65,14 @@ func NewServer(configuration Configuration) (*Server, error) { } func (s *Server) reloadKubernetesClient() error { - k, err := kubernetes.NewManager(s.configuration.Kubeconfig, s.configuration.StaticConfig) + k, err := kubernetes.NewManager(s.configuration.StaticConfig.KubeConfig, s.configuration.StaticConfig) if err != nil { return err } s.k = k applicableTools := make([]server.ServerTool, 0) for _, tool := range s.configuration.Profile.GetTools(s) { - if s.configuration.ReadOnly && !ptr.Deref(tool.Tool.Annotations.ReadOnlyHint, false) { - continue - } - if s.configuration.DisableDestructive && !ptr.Deref(tool.Tool.Annotations.ReadOnlyHint, false) && ptr.Deref(tool.Tool.Annotations.DestructiveHint, false) { + if !s.configuration.isToolApplicable(tool) { continue } applicableTools = append(applicableTools, tool) diff --git a/pkg/mcp/mcp_test.go b/pkg/mcp/mcp_test.go index b1ba3ef..da1c8ff 100644 --- a/pkg/mcp/mcp_test.go +++ b/pkg/mcp/mcp_test.go @@ -9,8 +9,12 @@ import ( "testing" "time" + "k8s.io/utils/ptr" + + "github.com/manusa/kubernetes-mcp-server/pkg/config" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" ) func TestWatchKubeConfig(t *testing.T) { @@ -52,7 +56,7 @@ func TestWatchKubeConfig(t *testing.T) { } func TestReadOnly(t *testing.T) { - readOnlyServer := func(c *mcpContext) { c.readOnly = true } + readOnlyServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{ReadOnly: true} } testCaseWithContext(t, &mcpContext{before: readOnlyServer}, func(c *mcpContext) { tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{}) t.Run("ListTools returns tools", func(t *testing.T) { @@ -73,8 +77,175 @@ func TestReadOnly(t *testing.T) { }) } +func TestIsToolApplicableReadOnly(t *testing.T) { + tests := []struct { + config Configuration + expected bool + tool server.ServerTool + }{ + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + ReadOnly: true, + }, + }, + expected: true, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Annotations: mcp.ToolAnnotation{ + ReadOnlyHint: ptr.To(true), + }, + }, + }, + }, + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + ReadOnly: true, + }, + }, + expected: false, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Annotations: mcp.ToolAnnotation{ + ReadOnlyHint: ptr.To(false), + }, + }, + }, + }, + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + DisableDestructive: true, + }, + }, + expected: true, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Annotations: mcp.ToolAnnotation{ + DestructiveHint: ptr.To(false), + }, + }, + }, + }, + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + DisableDestructive: true, + }, + }, + expected: true, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Annotations: mcp.ToolAnnotation{ + DestructiveHint: ptr.To(true), + ReadOnlyHint: ptr.To(true), + }, + }, + }, + }, + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + DisableDestructive: true, + }, + }, + expected: false, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Annotations: mcp.ToolAnnotation{ + DestructiveHint: ptr.To(true), + }, + }, + }, + }, + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + EnabledTools: []string{"namespaces_list"}, + }, + }, + expected: true, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Name: "namespaces_list", + }, + }, + }, + { + config: Configuration{ + StaticConfig: &config.StaticConfig{ + DisabledTools: []string{"namespaces_list"}, + }, + }, + expected: false, + tool: server.ServerTool{ + Tool: mcp.Tool{ + Name: "namespaces_list", + }, + }, + }, + } + for _, test := range tests { + t.Run("", func(t *testing.T) { + isToolApplicable := test.config.isToolApplicable(test.tool) + if isToolApplicable != test.expected { + t.Errorf("isToolApplicable should return %t, got %t", test.expected, isToolApplicable) + } + }) + } + +} + +func TestIsToolApplicableEnabledTools(t *testing.T) { + testCaseWithContext(t, &mcpContext{ + staticConfig: &config.StaticConfig{ + EnabledTools: []string{"namespaces_list", "events_list"}, + }, + }, func(c *mcpContext) { + tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{}) + t.Run("ListTools returns tools", func(t *testing.T) { + if err != nil { + t.Fatalf("call ListTools failed %v", err) + } + }) + t.Run("ListTools does not only return enabled tools", func(t *testing.T) { + if len(tools.Tools) != 2 { + t.Fatalf("ListTools should return 2 tools, got %d", len(tools.Tools)) + } + for _, tool := range tools.Tools { + if tool.Name != "namespaces_list" && tool.Name != "events_list" { + t.Errorf("Tool %s is not enabled but should be", tool.Name) + } + } + }) + }) +} + +func TestIsToolApplicableDisabledTools(t *testing.T) { + testCaseWithContext(t, &mcpContext{ + staticConfig: &config.StaticConfig{ + DisabledTools: []string{"namespaces_list", "events_list"}, + }, + }, func(c *mcpContext) { + tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{}) + t.Run("ListTools returns tools", func(t *testing.T) { + if err != nil { + t.Fatalf("call ListTools failed %v", err) + } + }) + t.Run("ListTools does not only return disabled tools", func(t *testing.T) { + for _, tool := range tools.Tools { + if tool.Name == "namespaces_list" || tool.Name == "events_list" { + t.Errorf("Tool %s is not disabled but should be", tool.Name) + } + } + }) + }) +} + func TestDisableDestructive(t *testing.T) { - disableDestructiveServer := func(c *mcpContext) { c.disableDestructive = true } + disableDestructiveServer := func(c *mcpContext) { c.staticConfig = &config.StaticConfig{DisableDestructive: true} } testCaseWithContext(t, &mcpContext{before: disableDestructiveServer}, func(c *mcpContext) { tools, err := c.mcpClient.ListTools(c.ctx, mcp.ListToolsRequest{}) t.Run("ListTools returns tools", func(t *testing.T) {