Skip to content

Commit 8983dcb

Browse files
committed
Propagate reactive Context to AsyncMcpToolCallback
- When calling tools while using ChatModel#stream, store the reactive context in a thread-local, so it can be used by downstream reactive tools. - In AsyncMcpToolCallback, restore the reactive context so it can be accessed by the tool. This will be useful for Spring Security OAuth2 support in reactive scenarios, because it relies on the context. Signed-off-by: Daniel Garnier-Moiroux <[email protected]>
1 parent 5aa8940 commit 8983dcb

File tree

12 files changed

+133
-30
lines changed

12 files changed

+133
-30
lines changed

mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.springframework.ai.chat.model.ToolContext;
2525
import org.springframework.ai.model.ModelOptionsUtils;
26+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
2627
import org.springframework.ai.tool.ToolCallback;
2728
import org.springframework.ai.tool.definition.DefaultToolDefinition;
2829
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -120,7 +121,7 @@ public String call(String functionInput) {
120121
new IllegalStateException("Error calling tool: " + response.content()));
121122
}
122123
return ModelOptionsUtils.toJsonString(response.content());
123-
}).block();
124+
}).contextWrite(ctx -> ctx.putAll(ToolCallReactiveContextHolder.getContext())).block();
124125
}
125126

126127
@Override

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.springframework.ai.content.Media;
6565
import org.springframework.ai.model.ModelOptionsUtils;
6666
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
67+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6768
import org.springframework.ai.model.tool.ToolCallingChatOptions;
6869
import org.springframework.ai.model.tool.ToolCallingManager;
6970
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
@@ -263,8 +264,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
263264
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
264265
// FIXME: bounded elastic needs to be used since tool calling
265266
// is currently only synchronous
266-
return Flux.defer(() -> {
267-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
267+
return Flux.deferContextual((ctx) -> {
268+
ToolExecutionResult toolExecutionResult;
269+
try {
270+
ToolCallReactiveContextHolder.setContext(ctx);
271+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
272+
} finally {
273+
ToolCallReactiveContextHolder.clearContext();
274+
}
268275
if (toolExecutionResult.returnDirect()) {
269276
// Return tool execution result directly to the client.
270277
return Flux.just(ChatResponse.builder().from(chatResponse)

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
import org.springframework.ai.model.tool.ToolCallingManager;
9696
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
9797
import org.springframework.ai.model.tool.ToolExecutionResult;
98+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
9899
import org.springframework.ai.observation.conventions.AiProvider;
99100
import org.springframework.ai.support.UsageCalculator;
100101
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -380,8 +381,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
380381
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
381382
// FIXME: bounded elastic needs to be used since tool calling
382383
// is currently only synchronous
383-
return Flux.defer(() -> {
384-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
384+
return Flux.deferContextual((ctx) -> {
385+
ToolExecutionResult toolExecutionResult;
386+
try {
387+
ToolCallReactiveContextHolder.setContext(ctx);
388+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
389+
}
390+
finally {
391+
ToolCallReactiveContextHolder.clearContext();
392+
}
385393
if (toolExecutionResult.returnDirect()) {
386394
// Return tool execution result directly to the client.
387395
return Flux.just(ChatResponse.builder()

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
import org.springframework.ai.model.tool.ToolCallingManager;
102102
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
103103
import org.springframework.ai.model.tool.ToolExecutionResult;
104+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
104105
import org.springframework.ai.observation.conventions.AiProvider;
105106
import org.springframework.ai.tool.definition.ToolDefinition;
106107
import org.springframework.util.Assert;
@@ -681,8 +682,15 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
681682

682683
// FIXME: bounded elastic needs to be used since tool calling
683684
// is currently only synchronous
684-
return Flux.defer(() -> {
685-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
685+
return Flux.deferContextual((ctx) -> {
686+
ToolExecutionResult toolExecutionResult;
687+
try {
688+
ToolCallReactiveContextHolder.setContext(ctx);
689+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
690+
}
691+
finally {
692+
ToolCallReactiveContextHolder.clearContext();
693+
}
686694

687695
if (toolExecutionResult.returnDirect()) {
688696
// Return tool execution result directly to the client.

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import org.springframework.ai.model.tool.ToolCallingManager;
6363
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6464
import org.springframework.ai.model.tool.ToolExecutionResult;
65+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6566
import org.springframework.ai.retry.RetryUtils;
6667
import org.springframework.ai.support.UsageCalculator;
6768
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -286,10 +287,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
286287
// @formatter:off
287288
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
288289
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
289-
return Flux.defer(() -> {
290-
// FIXME: bounded elastic needs to be used since tool calling
291-
// is currently only synchronous
292-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
290+
// FIXME: bounded elastic needs to be used since tool calling
291+
// is currently only synchronous
292+
return Flux.deferContextual((ctx) -> {
293+
ToolExecutionResult toolExecutionResult;
294+
try {
295+
ToolCallReactiveContextHolder.setContext(ctx);
296+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
297+
} finally {
298+
ToolCallReactiveContextHolder.clearContext();
299+
}
293300
if (toolExecutionResult.returnDirect()) {
294301
// Return tool execution result directly to the client.
295302
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import org.springframework.ai.model.tool.ToolCallingManager;
6666
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6767
import org.springframework.ai.model.tool.ToolExecutionResult;
68+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6869
import org.springframework.ai.retry.RetryUtils;
6970
import org.springframework.ai.tool.definition.ToolDefinition;
7071
import org.springframework.http.ResponseEntity;
@@ -370,10 +371,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
370371

371372
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
372373
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
373-
return Flux.defer(() -> {
374-
// FIXME: bounded elastic needs to be used since tool calling
375-
// is currently only synchronous
376-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
374+
// FIXME: bounded elastic needs to be used since tool calling
375+
// is currently only synchronous
376+
return Flux.deferContextual((ctx) -> {
377+
ToolExecutionResult toolExecutionResult;
378+
try {
379+
ToolCallReactiveContextHolder.setContext(ctx);
380+
toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
381+
} finally {
382+
ToolCallReactiveContextHolder.clearContext();
383+
}
377384
if (toolExecutionResult.returnDirect()) {
378385
// Return tool execution result directly to the client.
379386
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
import org.springframework.ai.model.tool.ToolCallingManager;
6565
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6666
import org.springframework.ai.model.tool.ToolExecutionResult;
67+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6768
import org.springframework.ai.retry.RetryUtils;
6869
import org.springframework.ai.support.UsageCalculator;
6970
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -316,8 +317,14 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
316317
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
317318
// FIXME: bounded elastic needs to be used since tool calling
318319
// is currently only synchronous
319-
return Flux.defer(() -> {
320-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
320+
return Flux.deferContextual((ctx) -> {
321+
ToolExecutionResult toolExecutionResult;
322+
try {
323+
ToolCallReactiveContextHolder.setContext(ctx);
324+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
325+
} finally {
326+
ToolCallReactiveContextHolder.clearContext();
327+
}
321328
if (toolExecutionResult.returnDirect()) {
322329
// Return tool execution result directly to the client.
323330
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.springframework.ai.model.tool.ToolCallingManager;
5555
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
5656
import org.springframework.ai.model.tool.ToolExecutionResult;
57+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
5758
import org.springframework.ai.ollama.api.OllamaApi;
5859
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
5960
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
@@ -351,8 +352,14 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
351352
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
352353
// FIXME: bounded elastic needs to be used since tool calling
353354
// is currently only synchronous
354-
return Flux.defer(() -> {
355-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
355+
return Flux.deferContextual((ctx) -> {
356+
ToolExecutionResult toolExecutionResult;
357+
try {
358+
ToolCallReactiveContextHolder.setContext(ctx);
359+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
360+
} finally {
361+
ToolCallReactiveContextHolder.clearContext();
362+
}
356363
if (toolExecutionResult.returnDirect()) {
357364
// Return tool execution result directly to the client.
358365
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.springframework.ai.model.tool.ToolCallingManager;
6262
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
6363
import org.springframework.ai.model.tool.ToolExecutionResult;
64+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
6465
import org.springframework.ai.openai.api.OpenAiApi;
6566
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
6667
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
@@ -363,10 +364,16 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
363364
// @formatter:off
364365
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
365366
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
366-
return Flux.defer(() -> {
367-
// FIXME: bounded elastic needs to be used since tool calling
368-
// is currently only synchronous
369-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
367+
// FIXME: bounded elastic needs to be used since tool calling
368+
// is currently only synchronous
369+
return Flux.deferContextual((ctx) -> {
370+
ToolExecutionResult toolExecutionResult;
371+
try {
372+
ToolCallReactiveContextHolder.setContext(ctx);
373+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
374+
} finally {
375+
ToolCallReactiveContextHolder.clearContext();
376+
}
370377
if (toolExecutionResult.returnDirect()) {
371378
// Return tool execution result directly to the client.
372379
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
import org.springframework.ai.model.tool.ToolCallingManager;
8282
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
8383
import org.springframework.ai.model.tool.ToolExecutionResult;
84+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
8485
import org.springframework.ai.retry.RetryUtils;
8586
import org.springframework.ai.support.UsageCalculator;
8687
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -540,9 +541,15 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
540541
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
541542
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
542543
// FIXME: bounded elastic needs to be used since tool calling
543-
// is currently only synchronous
544-
return Flux.defer(() -> {
545-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
544+
// is currently only synchronous
545+
return Flux.deferContextual((ctx) -> {
546+
ToolExecutionResult toolExecutionResult;
547+
try {
548+
ToolCallReactiveContextHolder.setContext(ctx);
549+
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
550+
} finally {
551+
ToolCallReactiveContextHolder.clearContext();
552+
}
546553
if (toolExecutionResult.returnDirect()) {
547554
// Return tool execution result directly to the client.
548555
return Flux.just(ChatResponse.builder().from(response)

models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import org.springframework.ai.model.tool.ToolCallingManager;
5757
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
5858
import org.springframework.ai.model.tool.ToolExecutionResult;
59+
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
5960
import org.springframework.ai.retry.RetryUtils;
6061
import org.springframework.ai.tool.definition.ToolDefinition;
6162
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
@@ -357,10 +358,16 @@ public Flux<ChatResponse> stream(Prompt prompt) {
357358
// @formatter:off
358359
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
359360
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) {
360-
return Flux.defer(() -> {
361-
// FIXME: bounded elastic needs to be used since tool calling
362-
// is currently only synchronous
363-
var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
361+
// FIXME: bounded elastic needs to be used since tool calling
362+
// is currently only synchronous
363+
return Flux.deferContextual((ctx) -> {
364+
ToolExecutionResult toolExecutionResult;
365+
try {
366+
ToolCallReactiveContextHolder.setContext(ctx);
367+
toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response);
368+
} finally {
369+
ToolCallReactiveContextHolder.clearContext();
370+
}
364371
if (toolExecutionResult.returnDirect()) {
365372
// Return tool execution result directly to the client.
366373
return Flux.just(ChatResponse.builder().from(response)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.springframework.ai.model.tool.internal;
2+
3+
import reactor.util.context.Context;
4+
import reactor.util.context.ContextView;
5+
6+
/**
7+
* This class bridges calling blocking tools within a reactive context. When calling
8+
* tools, itt captures the reactive context in a thread local, to then re-inject it in a
9+
* reactive call downstream.
10+
*
11+
* @author Daniel Garnier-Moiroux
12+
* @since 1.1.0
13+
*/
14+
public class ToolCallReactiveContextHolder {
15+
16+
private static final ThreadLocal<ContextView> context = ThreadLocal.withInitial(Context::empty);
17+
18+
public static void setContext(ContextView contextView) {
19+
context.set(contextView);
20+
}
21+
22+
public static ContextView getContext() {
23+
return context.get();
24+
}
25+
26+
public static void clearContext() {
27+
context.remove();
28+
}
29+
30+
}

0 commit comments

Comments
 (0)