Skip to content

Commit 3e362e2

Browse files
committed
fix: handle tool calls in retryStream
1 parent 7962414 commit 3e362e2

File tree

1 file changed

+67
-155
lines changed

1 file changed

+67
-155
lines changed

backend/cmd/chat/chat.go

Lines changed: 67 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ func retryStream(w http.ResponseWriter, r *http.Request) {
273273
ReasoningEffort: provider.ReasoningEffort(reasoningSetting),
274274
}
275275

276-
// Save assistant message after streaming completes
277276
responseMessage := Message{
278277
ID: -1,
279278
ConvID: req.ConversationID,
@@ -285,6 +284,9 @@ func retryStream(w http.ResponseWriter, r *http.Request) {
285284
Children: []int{},
286285
}
287286

287+
var toolCalls []tools.ToolCall
288+
var isToolsUsed bool
289+
288290
// Stream assistant content
289291
completion, err := providerClient.SendChatCompletionStreamRequest(providerParams, w)
290292
if err != nil {
@@ -295,6 +297,8 @@ func retryStream(w http.ResponseWriter, r *http.Request) {
295297
} else {
296298
responseMessage.Content = completion.Content
297299
responseMessage.Reasoning = completion.Reasoning
300+
toolCalls = completion.ToolCalls
301+
isToolsUsed = len(toolCalls) > 0
298302
}
299303

300304
responseID, saveErr := saveMessage(responseMessage)
@@ -306,6 +310,68 @@ func retryStream(w http.ResponseWriter, r *http.Request) {
306310
parent.Children = append(parent.Children, responseID)
307311
}
308312

313+
for len(toolCalls) > 0 {
314+
315+
toolCall := toolCalls[0]
316+
317+
providerParams.Messages = append(providerParams.Messages, provider.SimpleMessage{
318+
Role: "assistant",
319+
ToolCall: toolCall,
320+
})
321+
322+
toolCall.MessageID = responseMessage.ID
323+
toolCall.ConvID = req.ConversationID
324+
325+
output := tools.ExecuteToolCall(toolCall)
326+
toolCall.Output = output
327+
328+
chunk, _ := json.Marshal(provider.StreamChunk{
329+
ToolCall: toolCall,
330+
})
331+
fmt.Fprintf(w, "data: %s\n\n", chunk)
332+
flusher.Flush()
333+
334+
// Append tool result message to context for continued completion
335+
providerParams.Messages = append(providerParams.Messages, provider.SimpleMessage{
336+
Role: "tool",
337+
ToolCall: tools.ToolCall{
338+
ID: toolCall.ID,
339+
ReferenceID: toolCall.ReferenceID,
340+
Name: toolCall.Name,
341+
Output: output,
342+
},
343+
})
344+
345+
toolCalls = toolCalls[1:]
346+
if len(toolCalls) == 0 {
347+
completion, err = providerClient.SendChatCompletionStreamRequest(providerParams, w)
348+
if err != nil {
349+
log.Error("Error streaming chat completion after tool call", "err", err)
350+
fmt.Fprintf(w, "event: error\ndata: {\"error\": \"%s\"}\n\n", err.Error())
351+
flusher.Flush()
352+
responseMessage.Error = err.Error()
353+
break
354+
}
355+
toolCalls = append(toolCalls, completion.ToolCalls...)
356+
}
357+
358+
// Accumulate reasoning for all tool calls
359+
if responseMessage.Reasoning != "" {
360+
responseMessage.Reasoning += " \n`using tool:" + toolCall.Name + "`\n " + completion.Reasoning
361+
}
362+
}
363+
364+
// Update assistant message with full content after all tool calls
365+
if isToolsUsed {
366+
if err == nil {
367+
responseMessage.Content = completion.Content
368+
}
369+
_, err = updateMessage(responseMessage.ID, responseMessage)
370+
if err != nil {
371+
log.Error("Error updating assistant message after tool calls", "err", err)
372+
}
373+
}
374+
309375
// Send completion event with the new assistant message id
310376
completionData := provider.StreamComplete{
311377
UserMessageID: parent.ID,
@@ -344,157 +410,3 @@ func update(W http.ResponseWriter, R *http.Request) {
344410

345411
utils.RespondWithJSON(W, &response, http.StatusOK)
346412
}
347-
348-
// // Temporarily disabled
349-
// func chat(w http.ResponseWriter, r *http.Request) {
350-
// var req Request
351-
// err := utils.ExtractJSONBody(r, &req)
352-
// if err != nil || req.ConversationID == "" || req.Content == "" {
353-
// log.Error("Error unmarshalling request body", "err", err)
354-
// http.Error(w, "Invalid request body", http.StatusBadRequest)
355-
// return
356-
// }
357-
358-
// // find or create conversation
359-
// convID := req.ConversationID
360-
// err = repo.touchConversation(req.ConversationID)
361-
// if err != nil {
362-
// conv := newConversation("admin")
363-
// if err = repo.saveConversation(conv); err != nil {
364-
// log.Error("Error creating conversation", "err", err)
365-
// http.Error(w, fmt.Sprintf("Error creating conversation: %v", err), http.StatusInternalServerError)
366-
// return
367-
// }
368-
// convID = conv.ID
369-
// }
370-
371-
// userMessage := Message{
372-
// ID: -1,
373-
// ConvID: convID,
374-
// Role: "user",
375-
// Content: req.Content,
376-
// ParentID: req.ParentID,
377-
// Children: []int{},
378-
// Attachment: req.Attachment,
379-
// }
380-
381-
// userMessage.ID, err = saveMessage(userMessage)
382-
// if err != nil {
383-
// log.Error("Error saving user message", "err", err)
384-
// http.Error(w, fmt.Sprintf("Error saving user message: %v", err), http.StatusInternalServerError)
385-
// return
386-
// }
387-
388-
// ctx := buildContext(convID, userMessage.ID)
389-
// reasoningSetting, _ := getSetting("reasoningEffort")
390-
391-
// providerParams := provider.ProviderRequestParams{
392-
// Messages: ctx,
393-
// Model: req.Model,
394-
// ReasoningEffort: provider.ReasoningEffort(reasoningSetting),
395-
// }
396-
397-
// // send to provider
398-
// completion, err := providerClient.SendChatCompletionRequest(providerParams)
399-
// if err != nil {
400-
// log.Error("Error sending chat completion request", "err", err)
401-
// http.Error(w, fmt.Sprintf("Chat completion error: %v", err), http.StatusInternalServerError)
402-
// return
403-
// }
404-
405-
// responseMessage := Message{
406-
// ID: -1,
407-
// ConvID: convID,
408-
// Role: "assistant",
409-
// Model: req.Model,
410-
// Content: completion.Choices[0].Message.Content,
411-
// Reasoning: completion.Choices[0].Message.Reasoning,
412-
// ParentID: userMessage.ID,
413-
// Children: []int{},
414-
// }
415-
416-
// responseMessage.ID, err = saveMessage(responseMessage)
417-
// if err != nil {
418-
// log.Error("Error saving response message", "err", err)
419-
// http.Error(w, fmt.Sprintf("Error saving response message: %v", err), http.StatusInternalServerError)
420-
// return
421-
// }
422-
423-
// response := &Response{
424-
// Messages: make(map[int]*Message),
425-
// }
426-
// response.Messages[userMessage.ID] = &userMessage
427-
// response.Messages[responseMessage.ID] = &responseMessage
428-
429-
// utils.RespondWithJSON(w, &response, http.StatusOK)
430-
// }
431-
432-
// // Temporarily disabled
433-
// func retry(w http.ResponseWriter, r *http.Request) {
434-
// var req Retry
435-
// err := utils.ExtractJSONBody(r, &req)
436-
// if err != nil || req.ConversationID == "" {
437-
// log.Error("Error unmarshalling request body", "err", err)
438-
// http.Error(w, "Invalid request body", http.StatusBadRequest)
439-
// return
440-
// }
441-
442-
// err = repo.touchConversation(req.ConversationID)
443-
// if err != nil {
444-
// log.Error("Error retrieving conversation", "err", err)
445-
// http.Error(w, fmt.Sprintf("Error retrieving conversation: %v", err), http.StatusNotFound)
446-
// return
447-
// }
448-
449-
// parent, err := getMessage(req.ParentID)
450-
// if err != nil || parent.Role != "user" {
451-
// log.Error("Error retrieving parent message or invalid role", "err", err)
452-
// http.Error(w, "Invalid parent message", http.StatusBadRequest)
453-
// return
454-
// }
455-
456-
// ctx := buildContext(req.ConversationID, parent.ID)
457-
// reasoningSetting, _ := getSetting("reasoningEffort")
458-
459-
// providerParams := provider.ProviderRequestParams{
460-
// Messages: ctx,
461-
// Model: req.Model,
462-
// ReasoningEffort: provider.ReasoningEffort(reasoningSetting),
463-
// }
464-
465-
// completion, err := providerClient.SendChatCompletionRequest(providerParams)
466-
// if err != nil {
467-
// log.Error("Error sending chat completion request", "err", err)
468-
// http.Error(w, fmt.Sprintf("Chat completion error: %v", err), http.StatusInternalServerError)
469-
// return
470-
// }
471-
472-
// responseMessage := Message{
473-
// ID: -1,
474-
// ConvID: req.ConversationID,
475-
// Model: req.Model,
476-
// Role: "assistant",
477-
// Content: completion.Choices[0].Message.Content,
478-
// Reasoning: completion.Choices[0].Message.Reasoning,
479-
// ParentID: parent.ID,
480-
// Children: []int{},
481-
// }
482-
483-
// responseMessage.ID, err = saveMessage(responseMessage)
484-
// if err != nil {
485-
// log.Error("Error saving message", "err", err)
486-
// http.Error(w, fmt.Sprintf("Error saving message: %v", err), http.StatusInternalServerError)
487-
// return
488-
// }
489-
490-
// parent.Children = append(parent.Children, responseMessage.ID)
491-
492-
// response := &Response{
493-
// Messages: make(map[int]*Message),
494-
// }
495-
496-
// response.Messages[parent.ID] = parent
497-
// response.Messages[responseMessage.ID] = &responseMessage
498-
499-
// utils.RespondWithJSON(w, &response, http.StatusOK)
500-
// }

0 commit comments

Comments
 (0)