Skip to content

Commit ead0ec0

Browse files
committed
fix(anthropic): prevent streaming tool calling responses when internal execution is enabled
- For streaming, block tool calling ChatResponse unless internal execution is disabled - Add streaming validation test and debug logging Related to spring-projects#3640 Signed-off-by: Christian Tzolov <[email protected]>
1 parent 906299c commit ead0ec0

File tree

5 files changed

+86
-32
lines changed

5 files changed

+86
-32
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -260,26 +260,33 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
260260
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
261261
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);
262262

263-
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
264-
// FIXME: bounded elastic needs to be used since tool calling
265-
// is currently only synchronous
266-
return Flux.defer(() -> {
267-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
268-
if (toolExecutionResult.returnDirect()) {
269-
// Return tool execution result directly to the client.
270-
return Flux.just(ChatResponse.builder().from(chatResponse)
271-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
272-
.build());
273-
}
274-
else {
275-
// Send the tool execution result back to the model.
276-
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
277-
chatResponse);
278-
}
279-
}).subscribeOn(Schedulers.boundedElastic());
280-
}
263+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
264+
265+
if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
266+
// FIXME: bounded elastic needs to be used since tool calling
267+
// is currently only synchronous
268+
return Flux.defer(() -> {
269+
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
270+
if (toolExecutionResult.returnDirect()) {
271+
// Return tool execution result directly to the client.
272+
return Flux.just(ChatResponse.builder().from(chatResponse)
273+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
274+
.build());
275+
}
276+
else {
277+
// Send the tool execution result back to the model.
278+
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
279+
chatResponse);
280+
}
281+
}).subscribeOn(Schedulers.boundedElastic());
282+
} else {
283+
return Mono.empty();
284+
}
281285

282-
return Mono.just(chatResponse);
286+
} else {
287+
// If internal tool execution is not required, just return the chat response.
288+
return Mono.just(chatResponse);
289+
}
283290
})
284291
.doOnError(observation::error)
285292
.doFinally(s -> observation.stop())

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,8 @@
2424
import java.util.function.Consumer;
2525
import java.util.function.Predicate;
2626

27-
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
28-
import com.fasterxml.jackson.annotation.JsonInclude;
29-
import com.fasterxml.jackson.annotation.JsonInclude.Include;
30-
import com.fasterxml.jackson.annotation.JsonProperty;
31-
import com.fasterxml.jackson.annotation.JsonSubTypes;
32-
import com.fasterxml.jackson.annotation.JsonTypeInfo;
33-
import reactor.core.publisher.Flux;
34-
import reactor.core.publisher.Mono;
35-
27+
import org.slf4j.Logger;
28+
import org.slf4j.LoggerFactory;
3629
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
3730
import org.springframework.ai.model.ApiKey;
3831
import org.springframework.ai.model.ChatModelDescription;
@@ -52,6 +45,16 @@
5245
import org.springframework.web.client.RestClient;
5346
import org.springframework.web.reactive.function.client.WebClient;
5447

48+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
49+
import com.fasterxml.jackson.annotation.JsonInclude;
50+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
51+
import com.fasterxml.jackson.annotation.JsonProperty;
52+
import com.fasterxml.jackson.annotation.JsonSubTypes;
53+
import com.fasterxml.jackson.annotation.JsonTypeInfo;
54+
55+
import reactor.core.publisher.Flux;
56+
import reactor.core.publisher.Mono;
57+
5558
/**
5659
* The Anthropic API client.
5760
*
@@ -67,6 +70,8 @@
6770
*/
6871
public final class AnthropicApi {
6972

73+
private static final Logger logger = LoggerFactory.getLogger(AnthropicApi.class);
74+
7075
public static Builder builder() {
7176
return new Builder();
7277
}
@@ -222,6 +227,9 @@ public Flux<ChatCompletionResponse> chatCompletionStream(ChatCompletionRequest c
222227
.filter(event -> event.type() != EventType.PING)
223228
// Detect if the chunk is part of a streaming function call.
224229
.map(event -> {
230+
231+
logger.debug("Received event: {}", event);
232+
225233
if (this.streamHelper.isToolUseStart(event)) {
226234
isInsideTool.set(true);
227235
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.anthropic.client;
1818

19+
import static org.assertj.core.api.Assertions.assertThat;
20+
1921
import java.io.IOException;
2022
import java.net.URL;
2123
import java.util.Arrays;
@@ -29,8 +31,6 @@
2931
import org.junit.jupiter.params.provider.ValueSource;
3032
import org.slf4j.Logger;
3133
import org.slf4j.LoggerFactory;
32-
import reactor.core.publisher.Flux;
33-
3434
import org.springframework.ai.anthropic.AnthropicChatOptions;
3535
import org.springframework.ai.anthropic.AnthropicTestConfiguration;
3636
import org.springframework.ai.anthropic.api.AnthropicApi;
@@ -41,7 +41,9 @@
4141
import org.springframework.ai.chat.model.ChatResponse;
4242
import org.springframework.ai.converter.BeanOutputConverter;
4343
import org.springframework.ai.converter.ListOutputConverter;
44+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
4445
import org.springframework.ai.test.CurlyBracketEscaper;
46+
import org.springframework.ai.tool.annotation.Tool;
4547
import org.springframework.ai.tool.function.FunctionToolCallback;
4648
import org.springframework.beans.factory.annotation.Autowired;
4749
import org.springframework.beans.factory.annotation.Value;
@@ -53,7 +55,7 @@
5355
import org.springframework.test.context.ActiveProfiles;
5456
import org.springframework.util.MimeTypeUtils;
5557

56-
import static org.assertj.core.api.Assertions.assertThat;
58+
import reactor.core.publisher.Flux;
5759

5860
@SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429")
5961
@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
@@ -343,4 +345,39 @@ record ActorsFilms(String actor, List<String> movies) {
343345

344346
}
345347

348+
@ParameterizedTest(name = "{0} : {displayName} ")
349+
@ValueSource(strings = { "claude-3-7-sonnet-latest", "claude-sonnet-4-0" })
350+
void streamToolCallingResponseShouldNotContainToolCallMessages(String modelName) {
351+
352+
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
353+
354+
Flux<ChatResponse> responses = chatClient.prompt()
355+
.options(ToolCallingChatOptions.builder().model(modelName).build())
356+
.tools(new MyTools())
357+
.user("Get current weather in Amsterdam and Paris")
358+
// .user("Get current weather in Amsterdam. Please don't explain that you will
359+
// call tools.")
360+
.stream()
361+
.chatResponse();
362+
363+
List<ChatResponse> chatResponses = responses.collectList().block();
364+
365+
assertThat(chatResponses).isNotEmpty();
366+
367+
// Verify that none of the ChatResponse objects have tool calls
368+
chatResponses.forEach(chatResponse -> {
369+
logger.info("ChatResponse Results: {}", chatResponse.getResults());
370+
assertThat(chatResponse.hasToolCalls()).isFalse();
371+
});
372+
}
373+
374+
public static class MyTools {
375+
376+
@Tool(description = "Get the current weather forecast by city name")
377+
String getCurrentDateTime(String cityName) {
378+
return "For " + cityName + " Weather is hot and sunny with a temperature of 20 degrees";
379+
}
380+
381+
}
382+
346383
}

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ void streamingParameterLessTool(String modelName) {
290290
.map(cr -> cr.getResult().getOutput().getText())
291291
.collect(Collectors.joining());
292292

293-
assertThat(content).contains("20 degrees");
293+
assertThat(content).contains("20");
294294
}
295295

296296
public static class ParameterLessTools {

models/spring-ai-anthropic/src/test/resources/application-logging-test.properties

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515
#
1616

1717
logging.level.org.springframework.ai.chat.client.advisor=DEBUG
18+
19+
logging.level.org.springframework.ai.anthropic.api.AnthropicApi=DEBUG

0 commit comments

Comments
 (0)