Skip to content

Commit 6259a3d

Browse files
authored
Support custom tools (#223)
1 parent 5620699 commit 6259a3d

3 files changed

Lines changed: 155 additions & 0 deletions

File tree

cmd/main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ type Options struct {
8686
ExtraPromptPaths []string `json:"extraPromptPaths,omitempty"`
8787
TracePath string `json:"tracePath,omitempty"`
8888
RemoveWorkDir bool `json:"removeWorkDir,omitempty"`
89+
ToolConfigPath string `json:"toolConfigPath,omitempty"`
8990

9091
// UserInterface is the type of user interface to use.
9192
UserInterface UserInterface `json:"userInterface,omitempty"`
@@ -236,6 +237,12 @@ func run(ctx context.Context) error {
236237
return fmt.Errorf("failed to load config file: %w", err)
237238
}
238239

240+
// Load and register custom tools from config file
241+
if err := tools.LoadAndRegisterCustomTools(opt.ToolConfigPath); err != nil {
242+
// Log the error but continue execution, as custom tools are optional
243+
klog.Warningf("Failed to load or register custom tools (path: %q): %v", opt.ToolConfigPath, err)
244+
}
245+
239246
rootCmd, err := BuildRootCommand(&opt)
240247
if err != nil {
241248
return err
@@ -267,6 +274,7 @@ func (opt *Options) bindCLIFlags(f *pflag.FlagSet) error {
267274
f.StringVar(&opt.ModelID, "model", opt.ModelID, "language model e.g. gemini-2.0-flash-thinking-exp-01-21, gemini-2.0-flash")
268275
f.BoolVar(&opt.SkipPermissions, "skip-permissions", opt.SkipPermissions, "(dangerous) skip asking for confirmation before executing kubectl commands that modify resources")
269276
f.BoolVar(&opt.MCPServer, "mcp-server", opt.MCPServer, "run in MCP server mode")
277+
f.StringVar(&opt.ToolConfigPath, "custom-tools-config", opt.ToolConfigPath, "path to custom tools config file")
270278
f.BoolVar(&opt.EnableToolUseShim, "enable-tool-use-shim", opt.EnableToolUseShim, "enable tool use shim")
271279
f.BoolVar(&opt.Quiet, "quiet", opt.Quiet, "run in non-interactive mode, requires a query to be provided as a positional argument")
272280

pkg/tools/custom_tool.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package tools
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"os"
21+
"os/exec"
22+
"strings"
23+
24+
"github.com/GoogleCloudPlatform/kubectl-ai/gollm"
25+
)
26+
27+
// CustomToolConfig defines the structure for configuring a custom tool.
28+
type CustomToolConfig struct {
29+
Name string `yaml:"name"`
30+
Description string `yaml:"description"`
31+
Command string `yaml:"command"`
32+
CommandDesc string `yaml:"command_desc"`
33+
}
34+
35+
// CustomTool implements the Tool interface for external commands.
36+
type CustomTool struct {
37+
config CustomToolConfig
38+
}
39+
40+
// NewCustomTool creates a new CustomTool instance.
41+
func NewCustomTool(config CustomToolConfig) (*CustomTool, error) {
42+
if config.Name == "" {
43+
return nil, fmt.Errorf("custom tool name cannot be empty")
44+
}
45+
if len(config.Command) == 0 {
46+
return nil, fmt.Errorf("custom tool command cannot be empty for tool %q", config.Name)
47+
}
48+
49+
return &CustomTool{config: config}, nil
50+
}
51+
52+
// Name returns the tool's name.
53+
func (t *CustomTool) Name() string {
54+
return t.config.Name
55+
}
56+
57+
// Description returns the tool's description from its function definition.
58+
func (t *CustomTool) Description() string {
59+
return t.config.Description
60+
}
61+
62+
// FunctionDefinition returns the tool's function definition.
63+
func (t *CustomTool) FunctionDefinition() *gollm.FunctionDefinition {
64+
return &gollm.FunctionDefinition{
65+
Name: t.Name(),
66+
Description: t.Description(),
67+
Parameters: &gollm.Schema{
68+
Type: gollm.TypeObject,
69+
Properties: map[string]*gollm.Schema{
70+
"command": {
71+
Type: gollm.TypeString,
72+
Description: t.config.CommandDesc,
73+
},
74+
},
75+
},
76+
}
77+
}
78+
79+
// Run executes the external command defined for the custom tool.
80+
func (t *CustomTool) Run(ctx context.Context, args map[string]any) (any, error) {
81+
command := strings.Fields(t.config.Command)
82+
if len(command) == 0 {
83+
return nil, fmt.Errorf("empty command")
84+
}
85+
cmdArgs := []string{}
86+
if len(command) > 1 {
87+
cmdArgs = command[1:]
88+
}
89+
workDir := ctx.Value(WorkDirKey).(string)
90+
91+
cmd := exec.CommandContext(ctx, command[0], cmdArgs...)
92+
cmd.Dir = workDir
93+
cmd.Env = os.Environ()
94+
95+
return executeCommand(cmd)
96+
}

pkg/tools/tools.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@ import (
1919
"encoding/json"
2020
"fmt"
2121
"maps"
22+
"os"
23+
"path/filepath"
2224
"slices"
2325
"sort"
2426
"strings"
2527
"time"
2628

2729
"github.com/GoogleCloudPlatform/kubectl-ai/pkg/journal"
2830
"github.com/google/uuid"
31+
"sigs.k8s.io/yaml"
2932
)
3033

3134
type ContextKey string
@@ -180,3 +183,51 @@ func ToolResultToMap(result any) (map[string]any, error) {
180183
}
181184
return m, nil
182185
}
186+
187+
// LoadAndRegisterCustomTools loads tool configurations from a YAML file
188+
// and registers them.
189+
func LoadAndRegisterCustomTools(configPath string) error {
190+
if configPath == "" {
191+
// Default config path: ~/.config/kubectl-ai/tools.yaml
192+
home, err := os.UserHomeDir()
193+
if err != nil {
194+
return fmt.Errorf("failed to get user home directory for default custom tools config: %w", err)
195+
}
196+
configPath = filepath.Join(home, ".config", "kubectl-ai", "tools.yaml")
197+
}
198+
199+
yamlFile, err := os.ReadFile(configPath)
200+
if os.IsNotExist(err) {
201+
return nil
202+
} else if err != nil {
203+
return fmt.Errorf("failed to read config file %s: %w", configPath, err)
204+
}
205+
206+
var configs []CustomToolConfig
207+
err = yaml.Unmarshal(yamlFile, &configs)
208+
if err != nil {
209+
return fmt.Errorf("failed to parse YAML config file %s: %w", configPath, err)
210+
}
211+
212+
// Register each custom tool
213+
var registrationErrors []string
214+
for _, config := range configs {
215+
tool, err := NewCustomTool(config)
216+
if err != nil {
217+
registrationErrors = append(registrationErrors, fmt.Sprintf("failed to create tool %q: %v", config.Name, err))
218+
continue // Skip registration if creation failed
219+
}
220+
// Check for duplicate registration attempt
221+
if _, exists := allTools.tools[tool.Name()]; exists {
222+
registrationErrors = append(registrationErrors, fmt.Sprintf("tool %q already registered (possibly built-in), skipping custom definition", tool.Name()))
223+
continue
224+
}
225+
RegisterTool(tool)
226+
}
227+
228+
if len(registrationErrors) > 0 {
229+
return fmt.Errorf("encountered errors during custom tool registration:\n - %s", strings.Join(registrationErrors, "\n - "))
230+
}
231+
232+
return nil
233+
}

0 commit comments

Comments
 (0)