diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index a59e8a71e58..f1ffcc5ba22 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -101,6 +101,28 @@ public class MistralAiChatOptions implements ToolCallingChatOptions { */ private @JsonProperty("stop") List stop; + /** + * Number between -2.0 and 2.0. frequency_penalty penalizes the repetition of words + * based on their frequency in the generated text. A higher frequency penalty + * discourages the model from repeating words that have already appeared frequently in + * the output, promoting diversity and reducing repetition. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + + /** + * Number between -2.0 and 2.0. presence_penalty determines how much the model + * penalizes the repetition of words or phrases. A higher presence penalty encourages + * the model to use a wider variety of words and phrases, making the output more + * diverse and creative. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; + + /** + * Number of completions to return for each request, input tokens are only billed + * once. + */ + private @JsonProperty("n") Integer n; + /** * A list of tools the model may call. Currently, only functions are supported as a * tool. Use this to provide a list of functions the model may generate JSON inputs @@ -151,6 +173,9 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .topP(fromOptions.getTopP()) .responseFormat(fromOptions.getResponseFormat()) .stop(fromOptions.getStop()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .presencePenalty(fromOptions.getPresencePenalty()) + .n(fromOptions.getN()) .tools(fromOptions.getTools()) .toolChoice(fromOptions.getToolChoice()) .toolCallbacks(fromOptions.getToolCallbacks()) @@ -255,6 +280,32 @@ public void setTopP(Double topP) { this.topP = topP; } + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -325,18 +376,6 @@ public void setFunctions(Set functionNames) { this.setToolNames(functionNames); } - @Override - @JsonIgnore - public Double getFrequencyPenalty() { - return null; - } - - @Override - @JsonIgnore - public Double getPresencePenalty() { - return null; - } - @Override @JsonIgnore public Integer getTopK() { @@ -376,8 +415,8 @@ public MistralAiChatOptions copy() { @Override public int hashCode() { return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, - this.responseFormat, this.stop, this.tools, this.toolChoice, this.toolCallbacks, this.tools, - this.internalToolExecutionEnabled, this.toolContext); + this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools, + this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext); } @Override @@ -397,6 +436,8 @@ public boolean equals(Object obj) { && Objects.equals(this.safePrompt, other.safePrompt) && Objects.equals(this.randomSeed, other.randomSeed) && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop) + && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.n, other.n) && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.toolCallbacks, other.toolCallbacks) && Objects.equals(this.toolNames, other.toolNames) @@ -438,6 +479,21 @@ public Builder stop(List stop) { return this; } + public Builder frequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder n(Integer n) { + this.options.n = n; + return this; + } + public Builder temperature(Double temperature) { this.options.setTemperature(temperature); return this; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java index 31070144baf..7e2aeed21ea 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java @@ -74,6 +74,9 @@ void observationForChatOperation() { .stop(List.of("this-is-the-end")) .temperature(0.7) .topP(1.0) + .presencePenalty(0.0) + .frequencyPenalty(0.0) + .n(2) .build(); Prompt prompt = new Prompt("Why does a raven look like a desk?", options); @@ -95,6 +98,9 @@ void observationForStreamingChatOperation() { .stop(List.of("this-is-the-end")) .temperature(0.7) .topP(1.0) + .presencePenalty(0.0) + .frequencyPenalty(0.0) + .n(2) .build(); Prompt prompt = new Prompt("Why does a raven look like a desk?", options); @@ -133,9 +139,9 @@ private void validate(ChatResponseMetadata responseMetadata) { .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel() : KeyValue.NONE_VALUE) - .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString()) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0") .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048") - .doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString()) .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), "[\"this-is-the-end\"]") .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")