Skip to content

Commit 84c1713

Browse files
committed
Add tool context support to chat options and enhance function calling
This commit adds support for tool context in various chat options classes across different AI model implementations and enhances function calling capabilities. The tool context allows passing additional contextual information to function callbacks. - Add toolContext field to chat options classes - Update builder classes to support setting toolContext - Enhance FunctionCallback interface to support context-aware function calls - Update AbstractFunctionCallback to implement BiFunction instead of Function - Modify FunctionCallbackWrapper to support both Function and BiFunction and to use the new SchemaType location - Add support for BiFunction in TypeResolverHelper - Update ChatClient interface and DefaultChatClient implementation to support new function calling methods with Function, BiFunction and FunctionCallback arguments - Refactor AbstractToolCallSupport to pass tool context to function execution - Update all affected <Model>ChatOptions with tool context support - Simplify OpenAiChatClientMultipleFunctionCallsIT test - Add tests for function calling with tool context - Add new test cases for function callbacks with context in various integration tests - Modify existing tests to incorporate new context-aware function calling capabilities Resolves spring-projects#864, spring-projects#1303, spring-projects#991
1 parent 05292ac commit 84c1713

File tree

30 files changed

+665
-186
lines changed

30 files changed

+665
-186
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.ArrayList;
1919
import java.util.HashSet;
2020
import java.util.List;
21+
import java.util.Map;
2122
import java.util.Set;
2223

2324
import com.fasterxml.jackson.annotation.JsonIgnore;
@@ -80,6 +81,10 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
8081

8182
@JsonIgnore
8283
private Boolean proxyToolCalls;
84+
85+
@JsonIgnore
86+
private Map<String, Object> toolContext;
87+
8388
// @formatter:on
8489

8590
public static Builder builder() {
@@ -152,6 +157,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) {
152157
return this;
153158
}
154159

160+
public Builder withToolContext(Map<String, Object> toolContext) {
161+
if (this.options.toolContext == null) {
162+
this.options.toolContext = toolContext;
163+
}
164+
else {
165+
this.options.toolContext.putAll(toolContext);
166+
}
167+
return this;
168+
}
169+
155170
public AnthropicChatOptions build() {
156171
return this.options;
157172
}
@@ -263,6 +278,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) {
263278
this.proxyToolCalls = proxyToolCalls;
264279
}
265280

281+
@Override
282+
public Map<String, Object> getToolContext() {
283+
return this.toolContext;
284+
}
285+
286+
@Override
287+
public void setToolContext(Map<String, Object> toolContext) {
288+
this.toolContext = toolContext;
289+
}
290+
266291
@Override
267292
public AnthropicChatOptions copy() {
268293
return fromOptions(this);
@@ -279,6 +304,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
279304
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
280305
.withFunctions(fromOptions.getFunctions())
281306
.withProxyToolCalls(fromOptions.getProxyToolCalls())
307+
.withToolContext(fromOptions.getToolContext())
282308
.build();
283309
}
284310

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,17 @@
2222
import java.util.Map;
2323
import java.util.Set;
2424

25-
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
26-
import com.fasterxml.jackson.annotation.JsonIgnore;
27-
import com.fasterxml.jackson.annotation.JsonInclude;
28-
import com.fasterxml.jackson.annotation.JsonInclude.Include;
29-
import com.fasterxml.jackson.annotation.JsonProperty;
30-
3125
import org.springframework.ai.chat.prompt.ChatOptions;
3226
import org.springframework.ai.model.function.FunctionCallback;
3327
import org.springframework.ai.model.function.FunctionCallingOptions;
3428
import org.springframework.boot.context.properties.NestedConfigurationProperty;
3529
import org.springframework.util.Assert;
36-
import org.stringtemplate.v4.compiler.CodeGenerator.primary_return;
30+
31+
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
32+
import com.fasterxml.jackson.annotation.JsonIgnore;
33+
import com.fasterxml.jackson.annotation.JsonInclude;
34+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
35+
import com.fasterxml.jackson.annotation.JsonProperty;
3736

3837
/**
3938
* The configuration information for a chat completions request. Completions support a
@@ -199,6 +198,10 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
199198
@JsonIgnore
200199
private AzureChatEnhancementConfiguration enhancements;
201200

201+
@NestedConfigurationProperty
202+
@JsonIgnore
203+
private Map<String, Object> toolContext;
204+
202205
public static Builder builder() {
203206
return new Builder();
204207
}
@@ -312,6 +315,16 @@ public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements)
312315
return this;
313316
}
314317

318+
public Builder withToolContext(Map<String, Object> toolContext) {
319+
if (this.options.toolContext == null) {
320+
this.options.toolContext = toolContext;
321+
}
322+
else {
323+
this.options.toolContext.putAll(toolContext);
324+
}
325+
return this;
326+
}
327+
315328
public AzureOpenAiChatOptions build() {
316329
return this.options;
317330
}
@@ -498,6 +511,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) {
498511
this.proxyToolCalls = proxyToolCalls;
499512
}
500513

514+
@Override
515+
public Map<String, Object> getToolContext() {
516+
return this.toolContext;
517+
}
518+
519+
@Override
520+
public void setToolContext(Map<String, Object> toolContext) {
521+
this.toolContext = toolContext;
522+
}
523+
501524
@Override
502525
public AzureOpenAiChatOptions copy() {
503526
return fromOptions(this);
@@ -521,6 +544,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
521544
.withLogprobs(fromOptions.isLogprobs())
522545
.withTopLogprobs(fromOptions.getTopLogProbs())
523546
.withEnhancements(fromOptions.getEnhancements())
547+
.withToolContext(fromOptions.getToolContext())
524548
.build();
525549
}
526550

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.ArrayList;
3030
import java.util.HashSet;
3131
import java.util.List;
32+
import java.util.Map;
3233
import java.util.Set;
3334

3435
/**
@@ -145,6 +146,11 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions {
145146

146147
@JsonIgnore
147148
private Boolean proxyToolCalls;
149+
150+
@NestedConfigurationProperty
151+
@JsonIgnore
152+
private Map<String, Object> toolContext;
153+
148154
// @formatter:on
149155

150156
public static Builder builder() {
@@ -250,6 +256,16 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) {
250256
return this;
251257
}
252258

259+
public Builder withToolContext(Map<String, Object> toolContext) {
260+
if (this.options.toolContext == null) {
261+
this.options.toolContext = toolContext;
262+
}
263+
else {
264+
this.options.toolContext.putAll(toolContext);
265+
}
266+
return this;
267+
}
268+
253269
public MiniMaxChatOptions build() {
254270
return this.options;
255271
}
@@ -411,6 +427,16 @@ public void setProxyToolCalls(Boolean proxyToolCalls) {
411427
this.proxyToolCalls = proxyToolCalls;
412428
}
413429

430+
@Override
431+
public Map<String, Object> getToolContext() {
432+
return this.toolContext;
433+
}
434+
435+
@Override
436+
public void setToolContext(Map<String, Object> toolContext) {
437+
this.toolContext = toolContext;
438+
}
439+
414440
@Override
415441
public int hashCode() {
416442
final int prime = 31;
@@ -429,6 +455,7 @@ public int hashCode() {
429455
result = prime * result + ((tools == null) ? 0 : tools.hashCode());
430456
result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode());
431457
result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode());
458+
result = prime * result + ((toolContext == null) ? 0 : toolContext.hashCode());
432459
return result;
433460
}
434461

@@ -525,6 +552,14 @@ else if (!toolChoice.equals(other.toolChoice))
525552
}
526553
else if (!proxyToolCalls.equals(other.proxyToolCalls))
527554
return false;
555+
556+
if (this.toolContext == null) {
557+
if (other.toolContext != null)
558+
return false;
559+
}
560+
else if (!toolContext.equals(other.toolContext))
561+
return false;
562+
528563
return true;
529564
}
530565

@@ -550,6 +585,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) {
550585
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
551586
.withFunctions(fromOptions.getFunctions())
552587
.withProxyToolCalls(fromOptions.getProxyToolCalls())
588+
.withToolContext(fromOptions.getToolContext())
553589
.build();
554590
}
555591

0 commit comments

Comments
 (0)