Merged
62 changes: 42 additions & 20 deletions go/plugins/compat_oai/generate.go
@@ -21,6 +21,7 @@ import (

"github.com/firebase/genkit/go/ai"
"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/shared"
)

@@ -78,16 +79,9 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
case ai.RoleSystem:
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
case ai.RoleModel:
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))

am := openai.ChatCompletionAssistantMessageParam{}
if msg.Content[0].Text != "" {
am.Content.OfArrayOfContentParts = append(am.Content.OfArrayOfContentParts, openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
OfText: &openai.ChatCompletionContentPartTextParam{
Text: msg.Content[0].Text,
},
})
}
am.Content.OfString = param.NewOpt(content)
toolCalls := convertToolCalls(msg.Content)
if len(toolCalls) > 0 {
am.ToolCalls = (toolCalls)
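
The net effect of this hunk: the assistant turn's text is set as a plain string via param.NewOpt on Content.OfString, using the already-assembled content value rather than only msg.Content[0].Text, while tool calls continue to ride on am.ToolCalls. A minimal hedged sketch of the pattern (the literal text is invented, not from the PR):

am := openai.ChatCompletionAssistantMessageParam{}
// param.NewOpt marks the value as explicitly set, so the SDK serializes the
// assistant content as a single JSON string instead of an array of parts.
am.Content.OfString = param.NewOpt("It is 21°C in Paris.")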
@@ -267,11 +261,16 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co

var currentToolCall *ai.ToolRequest
var currentArguments string
var toolCallCollects []struct {
toolCall *ai.ToolRequest
args string
}

for stream.Next() {
chunk := stream.Current()
if len(chunk.Choices) > 0 {
choice := chunk.Choices[0]
modelChunk := &ai.ModelResponseChunk{}

switch choice.FinishReason {
case "tool_calls", "stop":
@@ -289,40 +288,60 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
// handle tool calls
for _, toolCall := range choice.Delta.ToolCalls {
// first tool call (= current tool call is nil) contains the tool call name
if currentToolCall != nil && toolCall.ID != "" && currentToolCall.Ref != toolCall.ID {
toolCallCollects = append(toolCallCollects, struct {
toolCall *ai.ToolRequest
args string
}{
toolCall: currentToolCall,
args: currentArguments,
})
currentToolCall = nil
currentArguments = ""
}

if currentToolCall == nil {
currentToolCall = &ai.ToolRequest{
Name: toolCall.Function.Name,
Ref: toolCall.ID,
}
}

if toolCall.Function.Arguments != "" {
currentArguments += toolCall.Function.Arguments
}

modelChunk.Content = append(modelChunk.Content, ai.NewToolRequestPart(&ai.ToolRequest{
Name: currentToolCall.Name,
Input: toolCall.Function.Arguments,
Ref: currentToolCall.Ref,
}))
}

// when tool call is complete
if choice.FinishReason == "tool_calls" && currentToolCall != nil {
// parse accumulated arguments string
for _, toolcall := range toolCallCollects {
toolcall.toolCall.Input = jsonStringToMap(toolcall.args)
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall))
}
if currentArguments != "" {
currentToolCall.Input = jsonStringToMap(currentArguments)
}

fullResponse.Message.Content = []*ai.Part{ai.NewToolRequestPart(currentToolCall)}
return &fullResponse, nil
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall))
}

content := chunk.Choices[0].Delta.Content
modelChunk := &ai.ModelResponseChunk{
Content: []*ai.Part{ai.NewTextPart(content)},
// when starting a tool call, the content is empty
if content != "" {
modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(content))
fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...)
}

if err := handleChunk(ctx, modelChunk); err != nil {
return nil, fmt.Errorf("callback error: %w", err)
}

fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...)

// Update Usage
fullResponse.Usage.InputTokens += int(chunk.Usage.PromptTokens)
fullResponse.Usage.OutputTokens += int(chunk.Usage.CompletionTokens)
fullResponse.Usage.TotalTokens += int(chunk.Usage.TotalTokens)
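
To make the accumulation above concrete: the OpenAI stream delivers a tool call's ID and name in its first delta and the JSON arguments in later fragments, so the code concatenates fragments per call (flushing the previous call into toolCallCollects whenever a new ID appears) and only parses once the finish reason is "tool_calls". A standalone toy sketch of that concatenate-then-parse step, with invented fragments and plain json.Unmarshal standing in for the jsonStringToMap helper referenced above:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Invented argument fragments as they might arrive across streamed deltas;
	// none of them is valid JSON on its own.
	fragments := []string{`{"loca`, `tion":"Par`, `is"}`}

	var accumulated string
	for _, f := range fragments {
		accumulated += f // mirrors currentArguments += toolCall.Function.Arguments
	}

	// Parse only after the stream signals the call is complete.
	var input map[string]any
	if err := json.Unmarshal([]byte(accumulated), &input); err != nil {
		panic(err)
	}
	fmt.Println(input) // map[location:Paris]
}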
@@ -379,14 +398,17 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
Input: jsonStringToMap(toolCall.Function.Arguments),
}))
}

// content and tool call may exist simultaneously
if completion.Choices[0].Message.Content != "" {
resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(completion.Choices[0].Message.Content))
}

if len(toolRequestParts) > 0 {
resp.Message.Content = toolRequestParts
resp.Message.Content = append(resp.Message.Content, toolRequestParts...)
return resp, nil
}

resp.Message.Content = []*ai.Part{
ai.NewTextPart(completion.Choices[0].Message.Content),
}
return resp, nil
}
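
A consequence of this hunk worth calling out: a completion can now surface a text part and tool-request parts in the same message instead of the tool requests replacing the text. A hypothetical caller-side sketch (not part of the PR), assuming the part accessors from genkit's ai package:

// Walk every part rather than assuming the response is either text or tool calls.
for _, part := range resp.Message.Content {
	switch {
	case part.IsToolRequest():
		fmt.Printf("tool request: %s %v\n", part.ToolRequest.Name, part.ToolRequest.Input)
	case part.IsText():
		fmt.Println("text:", part.Text)
	}
}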
