Skip to content

Commit cd264f4

Browse files
authored
(fix/go/compat_oai): update OpenAI message transformation (#3536)
1 parent c11397c commit cd264f4

File tree

1 file changed

+42
-20
lines changed

1 file changed

+42
-20
lines changed

go/plugins/compat_oai/generate.go

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/firebase/genkit/go/ai"
2323
"github.com/openai/openai-go"
24+
"github.com/openai/openai-go/packages/param"
2425
"github.com/openai/openai-go/shared"
2526
)
2627

@@ -78,16 +79,9 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
7879
case ai.RoleSystem:
7980
oaiMessages = append(oaiMessages, openai.SystemMessage(content))
8081
case ai.RoleModel:
81-
oaiMessages = append(oaiMessages, openai.AssistantMessage(content))
8282

8383
am := openai.ChatCompletionAssistantMessageParam{}
84-
if msg.Content[0].Text != "" {
85-
am.Content.OfArrayOfContentParts = append(am.Content.OfArrayOfContentParts, openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
86-
OfText: &openai.ChatCompletionContentPartTextParam{
87-
Text: msg.Content[0].Text,
88-
},
89-
})
90-
}
84+
am.Content.OfString = param.NewOpt(content)
9185
toolCalls := convertToolCalls(msg.Content)
9286
if len(toolCalls) > 0 {
9387
am.ToolCalls = (toolCalls)
@@ -267,11 +261,16 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
267261

268262
var currentToolCall *ai.ToolRequest
269263
var currentArguments string
264+
var toolCallCollects []struct {
265+
toolCall *ai.ToolRequest
266+
args string
267+
}
270268

271269
for stream.Next() {
272270
chunk := stream.Current()
273271
if len(chunk.Choices) > 0 {
274272
choice := chunk.Choices[0]
273+
modelChunk := &ai.ModelResponseChunk{}
275274

276275
switch choice.FinishReason {
277276
case "tool_calls", "stop":
@@ -289,40 +288,60 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
289288
// handle tool calls
290289
for _, toolCall := range choice.Delta.ToolCalls {
291290
// first tool call (= current tool call is nil) contains the tool call name
291+
if currentToolCall != nil && toolCall.ID != "" && currentToolCall.Ref != toolCall.ID {
292+
toolCallCollects = append(toolCallCollects, struct {
293+
toolCall *ai.ToolRequest
294+
args string
295+
}{
296+
toolCall: currentToolCall,
297+
args: currentArguments,
298+
})
299+
currentToolCall = nil
300+
currentArguments = ""
301+
}
302+
292303
if currentToolCall == nil {
293304
currentToolCall = &ai.ToolRequest{
294305
Name: toolCall.Function.Name,
306+
Ref: toolCall.ID,
295307
}
296308
}
297309

298310
if toolCall.Function.Arguments != "" {
299311
currentArguments += toolCall.Function.Arguments
300312
}
313+
314+
modelChunk.Content = append(modelChunk.Content, ai.NewToolRequestPart(&ai.ToolRequest{
315+
Name: currentToolCall.Name,
316+
Input: toolCall.Function.Arguments,
317+
Ref: currentToolCall.Ref,
318+
}))
301319
}
302320

303321
// when tool call is complete
304322
if choice.FinishReason == "tool_calls" && currentToolCall != nil {
305323
// parse accumulated arguments string
324+
for _, toolcall := range toolCallCollects {
325+
toolcall.toolCall.Input = jsonStringToMap(toolcall.args)
326+
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall))
327+
}
306328
if currentArguments != "" {
307329
currentToolCall.Input = jsonStringToMap(currentArguments)
308330
}
309-
310-
fullResponse.Message.Content = []*ai.Part{ai.NewToolRequestPart(currentToolCall)}
311-
return &fullResponse, nil
331+
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall))
312332
}
313333

314334
content := chunk.Choices[0].Delta.Content
315-
modelChunk := &ai.ModelResponseChunk{
316-
Content: []*ai.Part{ai.NewTextPart(content)},
335+
// when starting a tool call, the content is empty
336+
if content != "" {
337+
modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(content))
338+
fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...)
317339
}
318340

319341
if err := handleChunk(ctx, modelChunk); err != nil {
320342
return nil, fmt.Errorf("callback error: %w", err)
321343
}
322344

323-
fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...)
324-
325-
// Update Usage
326345
fullResponse.Usage.InputTokens += int(chunk.Usage.PromptTokens)
327346
fullResponse.Usage.OutputTokens += int(chunk.Usage.CompletionTokens)
328347
fullResponse.Usage.TotalTokens += int(chunk.Usage.TotalTokens)
@@ -379,14 +398,17 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
379398
Input: jsonStringToMap(toolCall.Function.Arguments),
380399
}))
381400
}
401+
402+
// content and tool call may exist simultaneously
403+
if completion.Choices[0].Message.Content != "" {
404+
resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(completion.Choices[0].Message.Content))
405+
}
406+
382407
if len(toolRequestParts) > 0 {
383-
resp.Message.Content = toolRequestParts
408+
resp.Message.Content = append(resp.Message.Content, toolRequestParts...)
384409
return resp, nil
385410
}
386411

387-
resp.Message.Content = []*ai.Part{
388-
ai.NewTextPart(completion.Choices[0].Message.Content),
389-
}
390412
return resp, nil
391413
}
392414

0 commit comments

Comments
 (0)