@@ -21,6 +21,7 @@ import (
2121
2222 "github.com/firebase/genkit/go/ai"
2323 "github.com/openai/openai-go"
24+ "github.com/openai/openai-go/packages/param"
2425 "github.com/openai/openai-go/shared"
2526)
2627
@@ -78,16 +79,9 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
7879 case ai .RoleSystem :
7980 oaiMessages = append (oaiMessages , openai .SystemMessage (content ))
8081 case ai .RoleModel :
81- oaiMessages = append (oaiMessages , openai .AssistantMessage (content ))
8282
8383 am := openai.ChatCompletionAssistantMessageParam {}
84- if msg .Content [0 ].Text != "" {
85- am .Content .OfArrayOfContentParts = append (am .Content .OfArrayOfContentParts , openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion {
86- OfText : & openai.ChatCompletionContentPartTextParam {
87- Text : msg .Content [0 ].Text ,
88- },
89- })
90- }
84+ am .Content .OfString = param .NewOpt (content )
9185 toolCalls := convertToolCalls (msg .Content )
9286 if len (toolCalls ) > 0 {
9387 am .ToolCalls = (toolCalls )
@@ -267,11 +261,16 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
267261
268262 var currentToolCall * ai.ToolRequest
269263 var currentArguments string
264+ var toolCallCollects []struct {
265+ toolCall * ai.ToolRequest
266+ args string
267+ }
270268
271269 for stream .Next () {
272270 chunk := stream .Current ()
273271 if len (chunk .Choices ) > 0 {
274272 choice := chunk .Choices [0 ]
273+ modelChunk := & ai.ModelResponseChunk {}
275274
276275 switch choice .FinishReason {
277276 case "tool_calls" , "stop" :
@@ -289,40 +288,60 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
289288 // handle tool calls
290289 for _ , toolCall := range choice .Delta .ToolCalls {
291290 // first tool call (= current tool call is nil) contains the tool call name
291+ if currentToolCall != nil && toolCall .ID != "" && currentToolCall .Ref != toolCall .ID {
292+ toolCallCollects = append (toolCallCollects , struct {
293+ toolCall * ai.ToolRequest
294+ args string
295+ }{
296+ toolCall : currentToolCall ,
297+ args : currentArguments ,
298+ })
299+ currentToolCall = nil
300+ currentArguments = ""
301+ }
302+
292303 if currentToolCall == nil {
293304 currentToolCall = & ai.ToolRequest {
294305 Name : toolCall .Function .Name ,
306+ Ref : toolCall .ID ,
295307 }
296308 }
297309
298310 if toolCall .Function .Arguments != "" {
299311 currentArguments += toolCall .Function .Arguments
300312 }
313+
314+ modelChunk .Content = append (modelChunk .Content , ai .NewToolRequestPart (& ai.ToolRequest {
315+ Name : currentToolCall .Name ,
316+ Input : toolCall .Function .Arguments ,
317+ Ref : currentToolCall .Ref ,
318+ }))
301319 }
302320
303321 // when tool call is complete
304322 if choice .FinishReason == "tool_calls" && currentToolCall != nil {
305323 // parse accumulated arguments string
324+ for _ , toolcall := range toolCallCollects {
325+ toolcall .toolCall .Input = jsonStringToMap (toolcall .args )
326+ fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (toolcall .toolCall ))
327+ }
306328 if currentArguments != "" {
307329 currentToolCall .Input = jsonStringToMap (currentArguments )
308330 }
309-
310- fullResponse .Message .Content = []* ai.Part {ai .NewToolRequestPart (currentToolCall )}
311- return & fullResponse , nil
331+ fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (currentToolCall ))
312332 }
313333
314334 content := chunk .Choices [0 ].Delta .Content
315- modelChunk := & ai.ModelResponseChunk {
316- Content : []* ai.Part {ai .NewTextPart (content )},
335+ // when starting a tool call, the content is empty
336+ if content != "" {
337+ modelChunk .Content = append (modelChunk .Content , ai .NewTextPart (content ))
338+ fullResponse .Message .Content = append (fullResponse .Message .Content , modelChunk .Content ... )
317339 }
318340
319341 if err := handleChunk (ctx , modelChunk ); err != nil {
320342 return nil , fmt .Errorf ("callback error: %w" , err )
321343 }
322344
323- fullResponse .Message .Content = append (fullResponse .Message .Content , modelChunk .Content ... )
324-
325- // Update Usage
326345 fullResponse .Usage .InputTokens += int (chunk .Usage .PromptTokens )
327346 fullResponse .Usage .OutputTokens += int (chunk .Usage .CompletionTokens )
328347 fullResponse .Usage .TotalTokens += int (chunk .Usage .TotalTokens )
@@ -379,14 +398,17 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
379398 Input : jsonStringToMap (toolCall .Function .Arguments ),
380399 }))
381400 }
401+
402+ // content and tool call may exist simultaneously
403+ if completion .Choices [0 ].Message .Content != "" {
404+ resp .Message .Content = append (resp .Message .Content , ai .NewTextPart (completion .Choices [0 ].Message .Content ))
405+ }
406+
382407 if len (toolRequestParts ) > 0 {
383- resp .Message .Content = toolRequestParts
408+ resp .Message .Content = append ( resp . Message . Content , toolRequestParts ... )
384409 return resp , nil
385410 }
386411
387- resp .Message .Content = []* ai.Part {
388- ai .NewTextPart (completion .Choices [0 ].Message .Content ),
389- }
390412 return resp , nil
391413}
392414
0 commit comments