diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index c437786dd5e..a849716e37f 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -89,7 +89,6 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; @@ -131,7 +130,7 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; - private FunctionCallingOptions defaultOptions; + private BedrockProxyChatOptions defaultOptions; /** * Observation registry used for instrumentation. @@ -144,7 +143,7 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch private ChatModelObservationConvention observationConvention; public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, - BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, BedrockProxyChatOptions defaultOptions, FunctionCallbackContext functionCallbackContext, List toolFunctionCallbacks, ObservationRegistry observationRegistry) { @@ -305,17 +304,14 @@ else if (message.getMessageType() == MessageType.TOOL) { .map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getContent()).build()) .toList(); - FunctionCallingOptions updatedRuntimeOptions = (FunctionCallingOptions) this.defaultOptions.copy(); + BedrockProxyChatOptions updatedRuntimeOptions = (BedrockProxyChatOptions) this.defaultOptions.copy(); if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof FunctionCallingOptions) { - var functionCallingOptions = (FunctionCallingOptions) prompt.getOptions(); - updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions) - .merge(functionCallingOptions); + if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { + updatedRuntimeOptions = (BedrockProxyChatOptions) updatedRuntimeOptions.merge(functionCallingOptions); } - else if (prompt.getOptions() instanceof ChatOptions) { - var chatOptions = (ChatOptions) prompt.getOptions(); - updatedRuntimeOptions = ((PortableFunctionCallingOptions) updatedRuntimeOptions).merge(chatOptions); + else if (prompt.getOptions() instanceof ChatOptions chatOptions) { + updatedRuntimeOptions = updatedRuntimeOptions.merge(chatOptions); } } @@ -334,6 +330,7 @@ else if (prompt.getOptions() instanceof ChatOptions) { ? updatedRuntimeOptions.getTemperature().floatValue() : null) .topP(updatedRuntimeOptions.getTopP() != null ? updatedRuntimeOptions.getTopP().floatValue() : null) .build(); + Document additionalModelRequestFields = ConverseApiUtils .getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions()); @@ -586,7 +583,7 @@ public static final class Builder { private Duration timeout = Duration.ofMinutes(10); - private FunctionCallingOptions defaultOptions = new FunctionCallingOptionsBuilder().build(); + private BedrockProxyChatOptions defaultOptions = BedrockProxyChatOptions.builder().build(); private FunctionCallbackContext functionCallbackContext; @@ -621,7 +618,7 @@ public Builder withTimeout(Duration timeout) { return this; } - public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) { + public Builder withDefaultOptions(BedrockProxyChatOptions defaultOptions) { Assert.notNull(defaultOptions, "'defaultOptions' must not be null."); this.defaultOptions = defaultOptions; return this; diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatOptions.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatOptions.java new file mode 100644 index 00000000000..bf474608f27 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatOptions.java @@ -0,0 +1,277 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.springframework.ai.bedrock.converse; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +public class BedrockProxyChatOptions implements FunctionCallingOptions { + + private List functionCallbacks = new ArrayList<>(); + + private Set functions = new HashSet<>(); + + private String model; + + private Double frequencyPenalty; + + private Integer maxTokens; + + private Double presencePenalty; + + private List stopSequences; + + private Double temperature; + + private Integer topK; + + private Double topP; + + private Boolean proxyToolCalls = false; + + private Map context = new HashMap<>(); + + private Map additional = new HashMap<>(); + + public static BedrockProxyChatOptionsBuilder builder() { + return new BedrockProxyChatOptionsBuilder(); + } + + @Override + public List getFunctionCallbacks() { + return Collections.unmodifiableList(this.functionCallbacks); + } + + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = new ArrayList<>(functionCallbacks); + } + + @Override + public Set getFunctions() { + return Collections.unmodifiableSet(this.functions); + } + + public void setFunctions(Set functions) { + Assert.notNull(functions, "Functions must not be null"); + this.functions = new HashSet<>(functions); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public List getStopSequences() { + return this.stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + public Boolean getProxyToolCalls() { + return this.proxyToolCalls; + } + + public void setProxyToolCalls(Boolean proxyToolCalls) { + this.proxyToolCalls = proxyToolCalls; + } + + public Map getToolContext() { + return Collections.unmodifiableMap(this.context); + } + + public void setToolContext(Map context) { + Assert.notNull(context, "Context must not be null"); + this.context = new HashMap<>(context); + } + + public Map getAdditional() { + return Collections.unmodifiableMap(this.additional); + } + + public void setAdditional(Map additional) { + Assert.notNull(additional, "Additional must not be null"); + this.additional = new HashMap<>(additional); + } + + @Override + public ChatOptions copy() { + return new BedrockProxyChatOptionsBuilder().model(this.model) + .frequencyPenalty(this.frequencyPenalty) + .maxTokens(this.maxTokens) + .presencePenalty(this.presencePenalty) + .stopSequences(this.stopSequences != null ? new ArrayList<>(this.stopSequences) : null) + .temperature(this.temperature) + .topK(this.topK) + .topP(this.topP) + .functions(new HashSet<>(this.functions)) + .functionCallbacks(new ArrayList<>(this.functionCallbacks)) + .proxyToolCalls(this.proxyToolCalls) + .toolContext(new HashMap<>(this.getToolContext())) + .additional(new HashMap<>(this.additional)) + .build(); + } + + public BedrockProxyChatOptions merge(FunctionCallingOptions options) { + + var builder = builder().model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.model) + .frequencyPenalty( + options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.frequencyPenalty) + .maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.maxTokens) + .presencePenalty(options.getPresencePenalty() != null ? options.getPresencePenalty() : this.presencePenalty) + .stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.stopSequences) + .temperature(options.getTemperature() != null ? options.getTemperature() : this.temperature) + .topK(options.getTopK() != null ? options.getTopK() : this.topK) + .topP(options.getTopP() != null ? options.getTopP() : this.topP) + .proxyToolCalls(options.getProxyToolCalls() != null ? options.getProxyToolCalls() : this.proxyToolCalls); + + Set functions = new HashSet<>(); + if (!CollectionUtils.isEmpty(this.functions)) { + functions.addAll(this.functions); + } + if (!CollectionUtils.isEmpty(options.getFunctions())) { + functions.addAll(options.getFunctions()); + } + builder.functions(functions); + + List functionCallbacks = new ArrayList<>(); + if (!CollectionUtils.isEmpty(this.functionCallbacks)) { + functionCallbacks.addAll(this.functionCallbacks); + } + if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) { + functionCallbacks.addAll(options.getFunctionCallbacks()); + } + builder.functionCallbacks(functionCallbacks); + + Map context = new HashMap<>(); + if (!CollectionUtils.isEmpty(this.context)) { + context.putAll(this.context); + } + if (!CollectionUtils.isEmpty(options.getToolContext())) { + context.putAll(options.getToolContext()); + } + builder.toolContext(context); + + Map additional = new HashMap<>(); + if (!CollectionUtils.isEmpty(this.additional)) { + context.putAll(this.additional); + } + + if (options instanceof BedrockProxyChatOptions bedrockProxyChatOptions) { + if (!CollectionUtils.isEmpty(bedrockProxyChatOptions.getAdditional())) { + additional.putAll(bedrockProxyChatOptions.getAdditional()); + } + } + builder.additional(additional); + + return builder.build(); + } + + public BedrockProxyChatOptions merge(ChatOptions options) { + + var builder = BedrockProxyChatOptions.builder() + .model(StringUtils.hasText(options.getModel()) ? options.getModel() : this.model) + .frequencyPenalty( + options.getFrequencyPenalty() != null ? options.getFrequencyPenalty() : this.frequencyPenalty) + .maxTokens(options.getMaxTokens() != null ? options.getMaxTokens() : this.maxTokens) + .presencePenalty(options.getPresencePenalty() != null ? options.getPresencePenalty() : this.presencePenalty) + .stopSequences(options.getStopSequences() != null ? options.getStopSequences() : this.stopSequences) + .temperature(options.getTemperature() != null ? options.getTemperature() : this.temperature) + .topK(options.getTopK() != null ? options.getTopK() : this.topK) + .topP(options.getTopP() != null ? options.getTopP() : this.topP); + + return builder.build(); + } + +} diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatOptionsBuilder.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatOptionsBuilder.java new file mode 100644 index 00000000000..0e3b8eac581 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatOptionsBuilder.java @@ -0,0 +1,151 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.util.Assert; + +/** + * Builder for {@link FunctionCallingOptions}. Using the {@link FunctionCallingOptions} + * permits options portability between different AI providers that support + * function-calling. + * + * @author Christian Tzolov + * @author Thomas Vitale + * @since 0.8.1 + */ +public class BedrockProxyChatOptionsBuilder { + + private BedrockProxyChatOptions options; + + BedrockProxyChatOptionsBuilder() { + this.options = new BedrockProxyChatOptions(); + } + + public BedrockProxyChatOptionsBuilder functionCallbacks(List functionCallbacks) { + this.options.setFunctionCallbacks(functionCallbacks); + return this; + } + + public BedrockProxyChatOptionsBuilder functionCallback(FunctionCallback functionCallback) { + Assert.notNull(functionCallback, "FunctionCallback must not be null"); + this.options.getFunctionCallbacks().add(functionCallback); + return this; + } + + public BedrockProxyChatOptionsBuilder functions(Set functions) { + this.options.setFunctions(functions); + return this; + } + + public BedrockProxyChatOptionsBuilder function(String function) { + Assert.notNull(function, "Function must not be null"); + var set = new HashSet<>(this.options.getFunctions()); + set.add(function); + this.options.setFunctions(set); + return this; + } + + public BedrockProxyChatOptionsBuilder model(String model) { + this.options.setModel(model); + return this; + } + + public BedrockProxyChatOptionsBuilder frequencyPenalty(Double frequencyPenalty) { + this.options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public BedrockProxyChatOptionsBuilder maxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public BedrockProxyChatOptionsBuilder presencePenalty(Double presencePenalty) { + this.options.setPresencePenalty(presencePenalty); + return this; + } + + public BedrockProxyChatOptionsBuilder stopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public BedrockProxyChatOptionsBuilder temperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public BedrockProxyChatOptionsBuilder topK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public BedrockProxyChatOptionsBuilder topP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public BedrockProxyChatOptionsBuilder proxyToolCalls(Boolean proxyToolCalls) { + this.options.setProxyToolCalls(proxyToolCalls); + return this; + } + + public BedrockProxyChatOptionsBuilder toolContext(Map context) { + Assert.notNull(context, "Tool context must not be null"); + Map newContext = new HashMap<>(this.options.getToolContext()); + newContext.putAll(context); + this.options.setToolContext(newContext); + return this; + } + + public BedrockProxyChatOptionsBuilder toolContext(String key, Object value) { + Assert.notNull(key, "Key must not be null"); + Assert.notNull(value, "Value must not be null"); + Map newContext = new HashMap<>(this.options.getToolContext()); + newContext.put(key, value); + this.options.setToolContext(newContext); + return this; + } + + public BedrockProxyChatOptionsBuilder additional(Map additional) { + Assert.notNull(additional, "Additional must not be null"); + this.options.setAdditional(additional); + return this; + } + + public BedrockProxyChatOptionsBuilder additional(String key, Object value) { + Assert.notNull(key, "Key must not be null"); + Assert.notNull(value, "Value must not be null"); + Map newAdditional = new HashMap<>(this.options.getAdditional()); + newAdditional.put(key, value); + this.options.setAdditional(newAdditional); + return this; + } + + public BedrockProxyChatOptions build() { + return this.options; + } + +} diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java index 29038d972f4..8b583acbbfd 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java @@ -45,6 +45,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart; +import org.springframework.ai.bedrock.converse.BedrockProxyChatOptions; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; @@ -255,9 +256,9 @@ else if (event.sdkEventType() == EventType.METADATA) { return event; } - @SuppressWarnings("unchecked") public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions defaultOptions, ModelOptions promptOptions) { + if (defaultOptions == null && promptOptions == null) { return null; } @@ -266,9 +267,19 @@ public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions de if (defaultOptions != null) { attributes.putAll(ModelOptionsUtils.objectToMap(defaultOptions)); + if (defaultOptions instanceof BedrockProxyChatOptions bedrockProxyChatOptions) { + if (!CollectionUtils.isEmpty(bedrockProxyChatOptions.getAdditional())) { + attributes.putAll(bedrockProxyChatOptions.getAdditional()); + } + } } if (promptOptions != null) { + if (promptOptions instanceof BedrockProxyChatOptions bedrockProxyChatOptions) { + if (!CollectionUtils.isEmpty(bedrockProxyChatOptions.getAdditional())) { + attributes.putAll(bedrockProxyChatOptions.getAdditional()); + } + } if (promptOptions instanceof ChatOptions runtimeOptions) { attributes.putAll(ModelOptionsUtils.objectToMap(runtimeOptions)); } @@ -283,6 +294,7 @@ public static Document getChatOptionsAdditionalModelRequestFields(ChatOptions de attributes.remove("functions"); attributes.remove("toolContext"); attributes.remove("functionCallbacks"); + attributes.remove("additional"); attributes.remove("temperature"); attributes.remove("topK"); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java index eaf220fb284..f2afdd0276b 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java @@ -21,7 +21,6 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; @@ -42,7 +41,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() { .withRegion(Region.US_EAST_1) .withTimeout(Duration.ofSeconds(120)) // .withRegion(Region.US_EAST_1) - .withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build()) + .withDefaultOptions(BedrockProxyChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java index 1b2be8a2724..1a49db335e1 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java @@ -17,11 +17,13 @@ package org.springframework.ai.bedrock.converse; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistry; import io.micrometer.observation.tck.TestObservationRegistryAssert; +import net.minidev.json.JSONObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; @@ -34,7 +36,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; @@ -88,6 +89,30 @@ void observationForChatOperation() { validate(responseMetadata, "[\"end_turn\"]"); } + @Test + void observationForChatOperation2() { + var options = BedrockProxyChatOptions.builder() + .model("anthropic.claude-3-5-sonnet-20240620-v1:0") + .maxTokens(2048) + .stopSequences(List.of("this-is-the-end")) + .temperature(0.7) + .topP(1.0) + .additional("top_k", 100) // Additional parameter + // .additional("tool_choice", new JSONObject(Map.of("type", "auto"))) // + // Additional parameter + .build(); + + Prompt prompt = new Prompt("Why does a raven look like a desk?", options); + + ChatResponse chatResponse = this.chatModel.call(prompt); + assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty(); + + ChatResponseMetadata responseMetadata = chatResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + validate(responseMetadata, "[\"end_turn\"]"); + } + @Test void observationForStreamingChatOperation() { var options = PortableFunctionCallingOptions.builder() @@ -174,7 +199,7 @@ public BedrockProxyChatModel bedrockConverseChatModel(ObservationRegistry observ .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) .withRegion(Region.US_EAST_1) .withObservationRegistry(observationRegistry) - .withDefaultOptions(FunctionCallingOptions.builder().withModel(modelId).build()) + .withDefaultOptions(BedrockProxyChatOptions.builder().model(modelId).build()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java index 88f5068025a..2693d1a4d94 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java @@ -16,7 +16,7 @@ package org.springframework.ai.autoconfigure.bedrock.converse; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.bedrock.converse.BedrockProxyChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -38,10 +38,10 @@ public class BedrockConverseProxyChatProperties { private boolean enabled = true; @NestedConfigurationProperty - private PortableFunctionCallingOptions options = PortableFunctionCallingOptions.builder() - .withTemperature(0.7) - .withMaxTokens(300) - .withTopK(10) + private BedrockProxyChatOptions options = BedrockProxyChatOptions.builder() + .temperature(0.7) + .maxTokens(300) + .topK(10) .build(); public boolean isEnabled() { @@ -52,11 +52,11 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } - public PortableFunctionCallingOptions getOptions() { + public BedrockProxyChatOptions getOptions() { return this.options; } - public void setOptions(PortableFunctionCallingOptions options) { + public void setOptions(BedrockProxyChatOptions options) { Assert.notNull(options, "PortableFunctionCallingOptions must not be null"); this.options = options; }