diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index cc320a54d5c..d5f3970ea5a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -118,8 +118,8 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { private void validateToolContextSupport(@Nullable ToolContext toolContext) { var isNonEmptyToolContextProvided = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()); var isToolContextAcceptedByMethod = Stream.of(this.toolMethod.getParameterTypes()) - .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class)); - if (isToolContextAcceptedByMethod && !isNonEmptyToolContextProvided) { + .anyMatch(type -> ClassUtils.isAssignable(ToolContext.class, type)); + if (isNonEmptyToolContextProvided && !isToolContextAcceptedByMethod) { throw new IllegalArgumentException("ToolContext is required by the method as an argument"); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java index b99faa71a3a..544d440a9a9 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackGenericTypesTest.java @@ -22,10 +22,12 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Tests for {@link MethodToolCallback} with generic types. @@ -137,6 +139,76 @@ void testNestedGenericType() throws Exception { assertThat(result).isEqualTo("2 maps processed: [{a=1, b=2}, {c=3, d=4}]"); } + @Test + void testToolContextType() throws Exception { + // Create a test object with a method that takes a List> + TestGenericClass testObject = new TestGenericClass(); + Method method = TestGenericClass.class.getMethod("processStringListInToolContext", ToolContext.class); + + // Create a tool definition + ToolDefinition toolDefinition = DefaultToolDefinition.builder() + .name("processToolContext") + .description("Process tool context") + .inputSchema("{}") + .build(); + + // Create a MethodToolCallback + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(toolDefinition) + .toolMethod(method) + .toolObject(testObject) + .build(); + + // Create an empty JSON input + String toolInput = """ + {} + """; + + // Create a toolContext + ToolContext toolContext = new ToolContext(Map.of("foo", "bar")); + + // Call the tool + String result = callback.call(toolInput, toolContext); + + // Verify the result + assertThat(result).isEqualTo("1 entries processed {foo=bar}"); + } + + @Test + void testToolContextTypeWithNonToolContextArgs() throws Exception { + // Create a test object with a method that takes a List + TestGenericClass testObject = new TestGenericClass(); + Method method = TestGenericClass.class.getMethod("processStringList", List.class); + + // Create a tool definition + ToolDefinition toolDefinition = DefaultToolDefinition.builder() + .name("processStringList") + .description("Process a list of strings") + .inputSchema("{}") + .build(); + + // Create a MethodToolCallback + MethodToolCallback callback = MethodToolCallback.builder() + .toolDefinition(toolDefinition) + .toolMethod(method) + .toolObject(testObject) + .build(); + + // Create a JSON input with a list of strings + String toolInput = """ + { + "strings": ["one", "two", "three"] + } + """; + + // Create a toolContext + ToolContext toolContext = new ToolContext(Map.of("foo", "bar")); + + // Call the tool and verify + assertThatThrownBy(() -> callback.call(toolInput, toolContext)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ToolContext is required by the method as an argument"); + } + /** * Test class with methods that use generic types. */ @@ -154,6 +226,11 @@ public String processListOfMaps(List> listOfMaps) { return listOfMaps.size() + " maps processed: " + listOfMaps; } + public String processStringListInToolContext(ToolContext toolContext) { + Map context = toolContext.getContext(); + return context.size() + " entries processed " + context; + } + } }