From 5edafc9782cfd67d05722a2c00e6e9ca490386fd Mon Sep 17 00:00:00 2001 From: Lorenzo Caenazzo Date: Fri, 17 May 2024 10:25:29 +0200 Subject: [PATCH] :sparkles: avoid second roundtrip in function callbacks --- .../ai/anthropic/AnthropicChatClient.java | 17 ++++++++-- .../azure/openai/AzureOpenAiChatClient.java | 17 ++++++++-- .../AzureOpenAiChatClientFunctionCallIT.java | 23 +++++++++++++ .../function/SpyingMockWeatherService.java | 34 +++++++++++++++++++ .../ai/mistralai/MistralAiChatClient.java | 23 ++++++++----- .../ai/openai/OpenAiChatClient.java | 19 ++++++++--- .../gemini/VertexAiGeminiChatClient.java | 29 +++++++++++----- .../function/AbstractFunctionCallSupport.java | 16 +++++---- .../function/AbstractFunctionCallback.java | 23 ++++++++++++- .../ai/model/function/FunctionCallback.java | 7 +++- .../function/FunctionCallbackContext.java | 14 ++++++-- .../function/FunctionCallbackWrapper.java | 34 ++++++++++++++++--- 12 files changed, 216 insertions(+), 40 deletions(-) create mode 100644 models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/SpyingMockWeatherService.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java index 0f9bcbf4430..a2e68bfc1a0 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -48,6 +49,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -391,6 +393,16 @@ public ChatCompletion build() { } + @Override + protected boolean hasReturningFunction(RequestMessage responseMessage) { + return responseMessage.content() + .stream() + .filter(c -> c.type() == MediaContent.Type.TOOL_USE) + .map(MediaContent::name) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + @Override protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, RequestMessage responseMessage, List conversationHistory) { @@ -414,8 +426,9 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques String functionResponse = this.functionCallbackRegister.get(functionName) .call(ModelOptionsUtils.toJsonString(functionArguments)); - - toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse)); + if (functionResponse != null) { + toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse)); + } } // Add the function response to the conversation. diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java index a49a42ff59f..28c90c79d2b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java @@ -28,6 +28,7 @@ import com.azure.ai.openai.models.ChatRequestSystemMessage; import com.azure.ai.openai.models.ChatRequestToolMessage; import com.azure.ai.openai.models.ChatRequestUserMessage; +import com.azure.ai.openai.models.ChatResponseMessage; import com.azure.ai.openai.models.CompletionsFinishReason; import com.azure.ai.openai.models.ContentFilterResultsForPrompt; import com.azure.ai.openai.models.FunctionCall; @@ -50,6 +51,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -513,6 +515,15 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { return copyOptions; } + @Override + protected boolean hasReturningFunction(ChatRequestMessage responseMessage) { + return ((ChatRequestAssistantMessage) responseMessage).getToolCalls() + .stream() + .map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName()) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + @Override protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage, List conversationHistory) { @@ -530,8 +541,10 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); - // Add the function response to the conversation. - conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId())); + if (functionResponse != null) { + // Add the function response to the conversation. + conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId())); + } } // Recursively call chatCompletionWithTools until the model doesn't call a diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatClientFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatClientFunctionCallIT.java index 08c81ebd136..ad250237cd4 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatClientFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatClientFunctionCallIT.java @@ -118,6 +118,29 @@ void streamFunctionCallTest() { assertThat(content).containsAnyOf("15.0", "15"); } + @Test + void functionCallWithoutCompleteRoundTrip() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + final var spyingMockWeatherService = new SpyingMockWeatherService(); + var promptOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(selectedModel) + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(spyingMockWeatherService) + .withName("getCurrentWeather") + .withDescription("Get the current weather in a given location") + .build())) + .build(); + + ChatResponse response = chatClient.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + final var interceptedRequest = spyingMockWeatherService.getInterceptedRequest(); + assertThat(interceptedRequest.location()).containsIgnoringCase("San Francisco"); + } + @SpringBootConfiguration public static class TestConfiguration { diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/SpyingMockWeatherService.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/SpyingMockWeatherService.java new file mode 100644 index 00000000000..afa83991f67 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/SpyingMockWeatherService.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.azure.openai.function; + +import java.util.function.Function; + +public class SpyingMockWeatherService implements Function { + + private MockWeatherService.Request interceptedRequest = null; + + @Override + public Void apply(MockWeatherService.Request request) { + interceptedRequest = request; + return null; + } + + public MockWeatherService.Request getInterceptedRequest() { + return interceptedRequest; + } + +} diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java index 98a25025d67..77a350a125b 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java @@ -33,6 +33,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -98,7 +99,6 @@ public ChatResponse call(Prompt prompt) { var request = createRequest(prompt, false); return retryTemplate.execute(ctx -> { - ResponseEntity completionEntity = this.callWithFunctionSupport(request); var chatCompletion = completionEntity.getBody(); @@ -239,13 +239,18 @@ private List getFunctionTools(Set functionNam }).toList(); } - // - // Function Calling Support - // + @Override + protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) { + return responseMessage.toolCalls() + .stream() + .map(toolCall -> toolCall.function().name()) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + @Override protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage, List conversationHistory) { - // Every tool-call item requires a separate function call and a response (TOOL) // message. for (ToolCall toolCall : responseMessage.toolCalls()) { @@ -258,10 +263,12 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques } String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); + if (functionResponse != null) { + // Add the function response to the conversation. + conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, + functionName, null)); + } - // Add the function response to the conversation. - conversationHistory - .add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null)); } // Recursively call chatCompletionWithTools until the model doesn't call a diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 6ec6904dbf1..71663420cfd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -27,6 +27,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; @@ -324,6 +325,15 @@ private List getFunctionTools(Set functionNames) }).toList(); } + @Override + protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) { + return responseMessage.toolCalls() + .stream() + .map(toolCall -> toolCall.function().name()) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + @Override protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage, List conversationHistory) { @@ -340,10 +350,11 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques } String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); - - // Add the function response to the conversation. - conversationHistory - .add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null)); + if (functionResponse != null) { + // Add the function response to the conversation. + conversationHistory + .add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null)); + } } // Recursively call chatCompletionWithTools until the model doesn't call a diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java index e30df03cef8..4d634d6dc1a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java @@ -44,6 +44,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata; import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage; @@ -57,6 +58,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -406,6 +408,14 @@ public void destroy() throws Exception { } } + @Override + protected boolean hasReturningFunction(Content responseMessage) { + final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName(); + return Optional.ofNullable(this.functionCallbackRegister.get(functionName)) + .map(FunctionCallback::returningFunction) + .orElse(false); + } + @Override protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage, List conversationHistory) { @@ -420,17 +430,18 @@ protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousReques } String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); - - Content contentFnResp = Content.newBuilder() - .addParts(Part.newBuilder() - .setFunctionResponse(FunctionResponse.newBuilder() - .setName(functionCall.getName()) - .setResponse(jsonToStruct(functionResponse)) + if (functionResponse != null) { + Content contentFnResp = Content.newBuilder() + .addParts(Part.newBuilder() + .setFunctionResponse(FunctionResponse.newBuilder() + .setName(functionCall.getName()) + .setResponse(jsonToStruct(functionResponse)) + .build()) .build()) - .build()) - .build(); + .build(); - conversationHistory.add(contentFnResp); + conversationHistory.add(contentFnResp); + } return new GeminiRequest(conversationHistory, previousRequest.model()); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java index d5be8ef6cad..de05b51db6f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java @@ -60,20 +60,17 @@ protected Set handleFunctionCallbackConfigurations(FunctionCallingOption if (options != null) { if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) { - options.getFunctionCallbacks().stream().forEach(functionCallback -> { - + options.getFunctionCallbacks().forEach(functionCallback -> { // Register the tool callback. if (isRuntimeCall) { this.functionCallbackRegister.put(functionCallback.getName(), functionCallback); + // Automatically enable the function, usually from prompt + // callback. + functionToCall.add(functionCallback.getName()); } else { this.functionCallbackRegister.putIfAbsent(functionCallback.getName(), functionCallback); } - - // Automatically enable the function, usually from prompt callback. - if (isRuntimeCall) { - functionToCall.add(functionCallback.getName()); - } }); } @@ -147,6 +144,9 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) { Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory); + if (!this.hasReturningFunction(responseMessage)) { + return response; + } return this.callWithFunctionSupport(newRequest); } @@ -180,6 +180,8 @@ protected Flux handleFunctionCallOrReturnStream(Req request, Flux re } + abstract protected boolean hasReturningFunction(Msg responseMessage); + abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage, List conversationHistory); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index f4cdd4ef8e7..f1867d1f388 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -45,6 +45,8 @@ abstract class AbstractFunctionCallback implements Function, Functio private final Class inputType; + private final Class outputType; + private final String inputTypeSchema; private final ObjectMapper objectMapper; @@ -62,12 +64,13 @@ abstract class AbstractFunctionCallback implements Function, Functio * or OpenAPI Schema)required by the Model's function calling protocol. * @param inputType Used to compute, the argument's JSON schema required by the * Model's function calling protocol. + * @param outputType Used to identify the scope of that function. * @param responseConverter Used to convert the function's output type to a string. * @param objectMapper Used to convert the function's input and output types to and * from JSON. */ protected AbstractFunctionCallback(String name, String description, String inputTypeSchema, Class inputType, - Function responseConverter, ObjectMapper objectMapper) { + Class outputType, Function responseConverter, ObjectMapper objectMapper) { Assert.notNull(name, "Name must not be null"); Assert.notNull(description, "Description must not be null"); Assert.notNull(inputType, "InputType must not be null"); @@ -77,6 +80,7 @@ protected AbstractFunctionCallback(String name, String description, String input this.name = name; this.description = description; this.inputType = inputType; + this.outputType = outputType; this.inputTypeSchema = inputTypeSchema; this.responseConverter = responseConverter; this.objectMapper = objectMapper; @@ -97,6 +101,11 @@ public String getInputTypeSchema() { return this.inputTypeSchema; } + @Override + public boolean returningFunction() { + return !outputType.isAssignableFrom(Void.class); + } + @Override public String call(String functionArguments) { @@ -104,6 +113,10 @@ public String call(String functionArguments) { I request = fromJson(functionArguments, inputType); // extend conversation with function response. + if (outputType.isAssignableFrom(Void.class)) { + this.apply(request); + return null; + } return this.andThen(this.responseConverter).apply(request); } @@ -123,6 +136,7 @@ public int hashCode() { result = prime * result + ((name == null) ? 0 : name.hashCode()); result = prime * result + ((description == null) ? 0 : description.hashCode()); result = prime * result + ((inputType == null) ? 0 : inputType.hashCode()); + result = prime * result + ((outputType == null) ? 0 : outputType.hashCode()); return result; } @@ -153,6 +167,13 @@ else if (!description.equals(other.description)) } else if (!inputType.equals(other.inputType)) return false; + + if (outputType == null) { + if (other.outputType != null) + return false; + } + else if (!outputType.equals(other.outputType)) + return false; return true; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java index 6f4e8bac482..c30f0adcc83 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java @@ -49,4 +49,9 @@ public interface FunctionCallback { */ public String call(String functionInput); -} \ No newline at end of file + /** + * @return This function return a value or not + */ + boolean returningFunction(); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index 59f43098753..07babd04cd3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.function; import java.lang.reflect.Type; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.annotation.JsonClassDescription; @@ -73,9 +74,10 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable "Functional bean with name: " + beanName + " does not exist in the context."); } - if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) { + if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType)) + || !Consumer.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) { throw new IllegalArgumentException( - "Function call Bean must be of type Function. Found: " + beanType.getTypeName()); + "Function call Bean must be of type Function or Consumer. Found: " + beanType.getTypeName()); } Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0); @@ -118,6 +120,14 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable .withInputType(functionInputClass) .build(); } + if (bean instanceof Consumer consumer) { + return FunctionCallbackWrapper.builder(consumer) + .withName(functionName) + .withSchemaType(this.schemaType) + .withDescription(functionDescription) + .withInputType(functionInputClass) + .build(); + } else { throw new IllegalArgumentException("Bean must be of type Function"); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java index d065a5fe4da..fad54fdc313 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.model.function; +import java.util.function.Consumer; import java.util.function.Function; import com.fasterxml.jackson.databind.DeserializationFeature; @@ -37,8 +38,9 @@ public class FunctionCallbackWrapper extends AbstractFunctionCallback function; private FunctionCallbackWrapper(String name, String description, String inputTypeSchema, Class inputType, - Function responseConverter, ObjectMapper objectMapper, Function function) { - super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); + Class outputType, Function responseConverter, ObjectMapper objectMapper, + Function function) { + super(name, description, inputTypeSchema, inputType, outputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); this.function = function; } @@ -48,6 +50,11 @@ private static Class resolveInputType(Function function) { return (Class) TypeResolverHelper.getFunctionInputClass((Class>) function.getClass()); } + @SuppressWarnings("unchecked") + private static Class resolveOutputType(Function function) { + return (Class) TypeResolverHelper.getFunctionOutputClass((Class>) function.getClass()); + } + @Override public O apply(I input) { return this.function.apply(input); @@ -57,6 +64,14 @@ public static Builder builder(Function function) { return new Builder<>(function); } + public static Builder builder(Consumer consumer) { + final Function adapter = i -> { + consumer.accept(i); + return null; + }; + return new Builder<>(adapter); + } + public static class Builder { public enum SchemaType { @@ -71,6 +86,8 @@ public enum SchemaType { private Class inputType; + private Class outputType; + private final Function function; private SchemaType schemaType = SchemaType.JSON_SCHEMA; @@ -108,6 +125,12 @@ public Builder withInputType(Class inputType) { return this; } + @SuppressWarnings("unchecked") + public Builder withOutputType(Class outputType) { + this.outputType = (Class) outputType; + return this; + } + public Builder withResponseConverter(Function responseConverter) { Assert.notNull(responseConverter, "ResponseConverter must not be null"); this.responseConverter = responseConverter; @@ -143,6 +166,9 @@ public FunctionCallbackWrapper build() { if (this.inputType == null) { this.inputType = resolveInputType(this.function); } + if (this.outputType == null) { + this.outputType = resolveOutputType(this.function); + } if (this.inputTypeSchema == null) { boolean upperCaseTypeValues = this.schemaType == SchemaType.OPEN_API_SCHEMA; @@ -150,9 +176,9 @@ public FunctionCallbackWrapper build() { } return new FunctionCallbackWrapper<>(this.name, this.description, this.inputTypeSchema, this.inputType, - this.responseConverter, this.objectMapper, this.function); + this.outputType, this.responseConverter, this.objectMapper, this.function); } } -} \ No newline at end of file +}