Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,23 +383,9 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod
modelName = genOpts.Model.Name()
}

var dynamicTools []Tool
tools := make([]string, len(genOpts.Tools))
toolNames := make(map[string]bool)
for i, toolRef := range genOpts.Tools {
name := toolRef.Name()
// Redundant duplicate tool check with GenerateWithRequest otherwise we will panic when we register the dynamic tools.
if toolNames[name] {
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: duplicate tool %q", name)
}
toolNames[name] = true
tools[i] = name
// Dynamic tools wouldn't have been registered by this point.
if LookupTool(r, name) == nil {
if tool, ok := toolRef.(Tool); ok {
dynamicTools = append(dynamicTools, tool)
}
}
toolNames, dynamicTools, err := resolveUniqueTools(r, genOpts.Tools)
if err != nil {
return nil, err
}

if len(dynamicTools) > 0 {
Expand Down Expand Up @@ -477,7 +463,7 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod
actionOpts := &GenerateActionOptions{
Model: modelName,
Messages: messages,
Tools: tools,
Tools: toolNames,
MaxTurns: genOpts.MaxTurns,
Config: genOpts.Config,
ToolChoice: genOpts.ToolChoice,
Expand Down
33 changes: 32 additions & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,24 +144,31 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
if modelRef, ok := execOpts.Model.(ModelRef); ok && execOpts.Config == nil {
execOpts.Config = modelRef.Config()
}

if execOpts.Config != nil {
actionOpts.Config = execOpts.Config
}

if len(execOpts.Documents) > 0 {
actionOpts.Docs = execOpts.Documents
}

if execOpts.ToolChoice != "" {
actionOpts.ToolChoice = execOpts.ToolChoice
}

if execOpts.Model != nil {
actionOpts.Model = execOpts.Model.Name()
}

if execOpts.MaxTurns != 0 {
actionOpts.MaxTurns = execOpts.MaxTurns
}

if execOpts.ReturnToolRequests != nil {
actionOpts.ReturnToolRequests = *execOpts.ReturnToolRequests
}

if execOpts.MessagesFn != nil {
m, err := buildVariables(execOpts.Input)
if err != nil {
Expand All @@ -180,7 +187,31 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
}
}

return GenerateWithRequest(ctx, p.registry, actionOpts, execOpts.Middleware, execOpts.Stream)
toolRefs := execOpts.Tools
if len(toolRefs) == 0 {
toolRefs = make([]ToolRef, 0, len(actionOpts.Tools))
for _, toolName := range actionOpts.Tools {
toolRefs = append(toolRefs, ToolName(toolName))
}
}

toolNames, newTools, err := resolveUniqueTools(p.registry, toolRefs)
if err != nil {
return nil, err
}
actionOpts.Tools = toolNames

r := p.registry
if len(newTools) > 0 {
if !r.IsChild() {
r = r.NewChild()
}
for _, t := range newTools {
t.Register(p.registry)
}
}

return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream)
}

// Render renders the prompt template based on user input.
Expand Down
55 changes: 55 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,61 @@ func TestValidPrompt(t *testing.T) {
},
},
},
{
name: "execute with tools overriding prompt-level tools",
model: model,
config: &GenerationCommonConfig{Temperature: 11},
inputType: HelloPromptInput{},
systemText: "say hello",
promptText: "my name is foo",
tools: []ToolRef{testTool(reg, "promptTool")},
input: HelloPromptInput{Name: "foo"},
executeOptions: []PromptExecuteOption{
WithInput(HelloPromptInput{Name: "foo"}),
WithTools(testTool(reg, "executeOverrideTool")),
},
wantTextOutput: "Echo: system: tool: say hello; my name is foo; ; Bar; ; config: {\n \"temperature\": 11\n}; context: null",
wantGenerated: &ModelRequest{
Config: &GenerationCommonConfig{
Temperature: 11,
},
Output: &ModelOutputConfig{
ContentType: "text/plain",
},
ToolChoice: "required",
Messages: []*Message{
{
Role: RoleSystem,
Content: []*Part{NewTextPart("say hello")},
},
{
Role: RoleUser,
Content: []*Part{NewTextPart("my name is foo")},
},
{
Role: RoleModel,
Content: []*Part{NewToolRequestPart(&ToolRequest{Name: "executeOverrideTool", Input: map[string]any{"Test": "Bar"}})},
},
{
Role: RoleTool,
Content: []*Part{NewToolResponsePart(&ToolResponse{Output: "Bar"})},
},
},
Tools: []*ToolDefinition{
{
Name: "executeOverrideTool",
Description: "use when need to execute a test",
InputSchema: map[string]any{
"additionalProperties": bool(false),
"properties": map[string]any{"Test": map[string]any{"type": string("string")}},
"required": []any{string("Test")},
"type": string("object"),
},
OutputSchema: map[string]any{"type": string("string")},
},
},
},
},
}

cmpPart := func(a, b *Part) bool {
Expand Down
24 changes: 24 additions & 0 deletions go/ai/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,27 @@ func (t *tool) Restart(p *Part, opts *RestartOptions) *Part {

return newToolReq
}

// resolveUniqueTools resolves the list of tool refs to a list of all tool names and new tools that must be registered.
// Returns an error if there are tool refs with duplicate names.
func resolveUniqueTools(r api.Registry, toolRefs []ToolRef) (toolNames []string, newTools []Tool, err error) {
toolMap := make(map[string]bool)

for _, toolRef := range toolRefs {
name := toolRef.Name()

if toolMap[name] {
return nil, nil, core.NewError(core.INVALID_ARGUMENT, "duplicate tool %q", name)
}
toolMap[name] = true
toolNames = append(toolNames, name)

if LookupTool(r, name) == nil {
if tool, ok := toolRef.(Tool); ok {
newTools = append(newTools, tool)
}
}
}

return toolNames, newTools, nil
}
Loading