Skip to content

Propagate reactive Context to AsyncMcpToolCallback #3640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -120,7 +121,7 @@ public String call(String functionInput) {
new IllegalStateException("Error calling tool: " + response.content()));
}
return ModelOptionsUtils.toJsonString(response.content());
}).block();
}).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
Expand Down Expand Up @@ -263,8 +264,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(chatResponse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -380,8 +381,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -681,8 +682,15 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh

// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}

if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -286,10 +287,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
// @formatter:off
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
return Flux.defer(() -> {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -370,10 +371,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {

Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
return Flux.defer(() -> {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -316,8 +317,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
Expand Down Expand Up @@ -351,8 +352,14 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
Expand Down Expand Up @@ -363,10 +364,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
// @formatter:off
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
return Flux.defer(() -> {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -540,9 +541,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
// is currently only synchronous
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
Expand Down Expand Up @@ -357,10 +358,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// @formatter:off
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
return Flux.defer(() -> {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual((ctx) -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
} finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.springframework.ai.model.tool.internal;

import reactor.util.context.Context;
import reactor.util.context.ContextView;

/**
* This class bridges blocking Tools call and the reactive context. When calling tools, it
* captures the context in a thread local, making it available to re-inject in a nested
* reactive call.
*
* @author Daniel Garnier-Moiroux
* @since 1.1.0
*/
public class ToolCallReactiveContextHolder {

private static final ThreadLocal<ContextView> context = ThreadLocal.withInitial(Context::empty);

public static void setContext(ContextView contextView) {
context.set(contextView);
}

public static ContextView getContext() {
return context.get();
}

public static void clearContext() {
context.remove();
}

}