Skip to content

Commit b74e308

Browse files
jee14sobychacko
authored andcommitted
Add three request body parameters to Mistral AI Chat Completion
Add presence_penalty, frequency_penalty, and n parameters Following Mistral AI API specifications as referenced in https://docs.mistral.ai/api/#tag/chat Rename Builder method N() to n() Signed-off-by: Seunghyeon Ji <[email protected]>
1 parent 1e1ad41 commit b74e308

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,28 @@ public class MistralAiChatOptions implements ToolCallingChatOptions {
100100
*/
101101
private @JsonProperty("stop") List<String> stop;
102102

103+
/**
104+
* Number between -2.0 and 2.0. frequency_penalty penalizes the repetition of words
105+
* based on their frequency in the generated text. A higher frequency penalty
106+
* discourages the model from repeating words that have already appeared frequently in
107+
* the output, promoting diversity and reducing repetition.
108+
*/
109+
private @JsonProperty("frequency_penalty") Double frequencyPenalty;
110+
111+
/**
112+
* Number between -2.0 and 2.0. presence_penalty determines how much the model
113+
* penalizes the repetition of words or phrases. A higher presence penalty encourages
114+
* the model to use a wider variety of words and phrases, making the output more
115+
* diverse and creative.
116+
*/
117+
private @JsonProperty("presence_penalty") Double presencePenalty;
118+
119+
/**
120+
* Number of completions to return for each request, input tokens are only billed
121+
* once.
122+
*/
123+
private @JsonProperty("n") Integer n;
124+
103125
/**
104126
* A list of tools the model may call. Currently, only functions are supported as a
105127
* tool. Use this to provide a list of functions the model may generate JSON inputs
@@ -150,6 +172,9 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions)
150172
.topP(fromOptions.getTopP())
151173
.responseFormat(fromOptions.getResponseFormat())
152174
.stop(fromOptions.getStop())
175+
.frequencyPenalty(fromOptions.getFrequencyPenalty())
176+
.presencePenalty(fromOptions.getPresencePenalty())
177+
.n(fromOptions.getN())
153178
.tools(fromOptions.getTools())
154179
.toolChoice(fromOptions.getToolChoice())
155180
.toolCallbacks(fromOptions.getToolCallbacks())
@@ -254,6 +279,32 @@ public void setTopP(Double topP) {
254279
this.topP = topP;
255280
}
256281

282+
@Override
283+
public Double getFrequencyPenalty() {
284+
return this.frequencyPenalty;
285+
}
286+
287+
public void setFrequencyPenalty(Double frequencyPenalty) {
288+
this.frequencyPenalty = frequencyPenalty;
289+
}
290+
291+
@Override
292+
public Double getPresencePenalty() {
293+
return this.presencePenalty;
294+
}
295+
296+
public void setPresencePenalty(Double presencePenalty) {
297+
this.presencePenalty = presencePenalty;
298+
}
299+
300+
public Integer getN() {
301+
return this.n;
302+
}
303+
304+
public void setN(Integer n) {
305+
this.n = n;
306+
}
307+
257308
@Override
258309
@JsonIgnore
259310
public List<ToolCallback> getToolCallbacks() {
@@ -296,18 +347,6 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
296347
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
297348
}
298349

299-
@Override
300-
@JsonIgnore
301-
public Double getFrequencyPenalty() {
302-
return null;
303-
}
304-
305-
@Override
306-
@JsonIgnore
307-
public Double getPresencePenalty() {
308-
return null;
309-
}
310-
311350
@Override
312351
@JsonIgnore
313352
public Integer getTopK() {
@@ -334,8 +373,8 @@ public MistralAiChatOptions copy() {
334373
@Override
335374
public int hashCode() {
336375
return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed,
337-
this.responseFormat, this.stop, this.tools, this.toolChoice, this.toolCallbacks, this.tools,
338-
this.internalToolExecutionEnabled, this.toolContext);
376+
this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools,
377+
this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext);
339378
}
340379

341380
@Override
@@ -355,6 +394,8 @@ public boolean equals(Object obj) {
355394
&& Objects.equals(this.safePrompt, other.safePrompt)
356395
&& Objects.equals(this.randomSeed, other.randomSeed)
357396
&& Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop)
397+
&& Objects.equals(this.frequencyPenalty, other.frequencyPenalty)
398+
&& Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.n, other.n)
358399
&& Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice)
359400
&& Objects.equals(this.toolCallbacks, other.toolCallbacks)
360401
&& Objects.equals(this.toolNames, other.toolNames)
@@ -396,6 +437,21 @@ public Builder stop(List<String> stop) {
396437
return this;
397438
}
398439

440+
public Builder frequencyPenalty(Double frequencyPenalty) {
441+
this.options.frequencyPenalty = frequencyPenalty;
442+
return this;
443+
}
444+
445+
public Builder presencePenalty(Double presencePenalty) {
446+
this.options.presencePenalty = presencePenalty;
447+
return this;
448+
}
449+
450+
public Builder n(Integer n) {
451+
this.options.n = n;
452+
return this;
453+
}
454+
399455
public Builder temperature(Double temperature) {
400456
this.options.setTemperature(temperature);
401457
return this;

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ void observationForChatOperation() {
7474
.stop(List.of("this-is-the-end"))
7575
.temperature(0.7)
7676
.topP(1.0)
77+
.presencePenalty(0.0)
78+
.frequencyPenalty(0.0)
79+
.n(2)
7780
.build();
7881

7982
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -95,6 +98,9 @@ void observationForStreamingChatOperation() {
9598
.stop(List.of("this-is-the-end"))
9699
.temperature(0.7)
97100
.topP(1.0)
101+
.presencePenalty(0.0)
102+
.frequencyPenalty(0.0)
103+
.n(2)
98104
.build();
99105

100106
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -133,9 +139,9 @@ private void validate(ChatResponseMetadata responseMetadata) {
133139
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
134140
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
135141
: KeyValue.NONE_VALUE)
136-
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString())
142+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
143+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
137144
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
138-
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString())
139145
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
140146
"[\"this-is-the-end\"]")
141147
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")

0 commit comments

Comments
 (0)