Skip to content

Commit e278508

Browse files
committed
feat: Add three request body parameters to Chat Completion
Add presence_penalty, frequency_penalty, and n parameters to Chat Completion request body following Mistral AI API specifications as referenced in https://docs.mistral.ai/api/#tag/chat
1 parent 8f20aab commit e278508

File tree

2 files changed

+77
-16
lines changed

2 files changed

+77
-16
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ public class MistralAiChatOptions implements ToolCallingChatOptions {
101101
*/
102102
private @JsonProperty("stop") List<String> stop;
103103

104+
/**
105+
* Number between -2.0 and 2.0. frequency_penalty penalizes the repetition of words
106+
* based on their frequency in the generated text. A higher frequency penalty discourages
107+
* the model from repeating words that have already appeared frequently in the
108+
* output, promoting diversity and reducing repetition.
109+
*/
110+
private @JsonProperty("frequency_penalty") Double frequencyPenalty;
111+
112+
/**
113+
* Number between -2.0 and 2.0. presence_penalty determines how much the model penalizes
114+
* the repetition of words or phrases. A higher presence penalty encourages the model to
115+
* use a wider variety of words and phrases, making the output more diverse and creative.
116+
*/
117+
private @JsonProperty("presence_penalty") Double presencePenalty;
118+
119+
/**
120+
* Number of completions to return for each request, input tokens are only billed once.
121+
*/
122+
private @JsonProperty("n") Integer n;
123+
104124
/**
105125
* A list of tools the model may call. Currently, only functions are supported as a
106126
* tool. Use this to provide a list of functions the model may generate JSON inputs
@@ -151,6 +171,9 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions)
151171
.topP(fromOptions.getTopP())
152172
.responseFormat(fromOptions.getResponseFormat())
153173
.stop(fromOptions.getStop())
174+
.frequencyPenalty(fromOptions.getFrequencyPenalty())
175+
.presencePenalty(fromOptions.getPresencePenalty())
176+
.N(fromOptions.getN())
154177
.tools(fromOptions.getTools())
155178
.toolChoice(fromOptions.getToolChoice())
156179
.toolCallbacks(fromOptions.getToolCallbacks())
@@ -255,6 +278,32 @@ public void setTopP(Double topP) {
255278
this.topP = topP;
256279
}
257280

281+
@Override
282+
public Double getFrequencyPenalty() {
283+
return this.frequencyPenalty;
284+
}
285+
286+
public void setFrequencyPenalty(Double frequencyPenalty) {
287+
this.frequencyPenalty = frequencyPenalty;
288+
}
289+
290+
@Override
291+
public Double getPresencePenalty() {
292+
return this.presencePenalty;
293+
}
294+
295+
public void setPresencePenalty(Double presencePenalty) {
296+
this.presencePenalty = presencePenalty;
297+
}
298+
299+
public Integer getN() {
300+
return this.n;
301+
}
302+
303+
public void setN(Integer n) {
304+
this.n = n;
305+
}
306+
258307
@Override
259308
@JsonIgnore
260309
public List<FunctionCallback> getToolCallbacks() {
@@ -325,18 +374,6 @@ public void setFunctions(Set<String> functionNames) {
325374
this.setToolNames(functionNames);
326375
}
327376

328-
@Override
329-
@JsonIgnore
330-
public Double getFrequencyPenalty() {
331-
return null;
332-
}
333-
334-
@Override
335-
@JsonIgnore
336-
public Double getPresencePenalty() {
337-
return null;
338-
}
339-
340377
@Override
341378
@JsonIgnore
342379
public Integer getTopK() {
@@ -376,8 +413,8 @@ public MistralAiChatOptions copy() {
376413
@Override
377414
public int hashCode() {
378415
return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed,
379-
this.responseFormat, this.stop, this.tools, this.toolChoice, this.toolCallbacks, this.tools,
380-
this.internalToolExecutionEnabled, this.toolContext);
416+
this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools,
417+
this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext);
381418
}
382419

383420
@Override
@@ -397,6 +434,9 @@ public boolean equals(Object obj) {
397434
&& Objects.equals(this.safePrompt, other.safePrompt)
398435
&& Objects.equals(this.randomSeed, other.randomSeed)
399436
&& Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop)
437+
&& Objects.equals(this.frequencyPenalty, other.frequencyPenalty)
438+
&& Objects.equals(this.presencePenalty, other.presencePenalty)
439+
&& Objects.equals(this.n, other.n)
400440
&& Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice)
401441
&& Objects.equals(this.toolCallbacks, other.toolCallbacks)
402442
&& Objects.equals(this.toolNames, other.toolNames)
@@ -438,6 +478,21 @@ public Builder stop(List<String> stop) {
438478
return this;
439479
}
440480

481+
public Builder frequencyPenalty(Double frequencyPenalty) {
482+
this.options.frequencyPenalty = frequencyPenalty;
483+
return this;
484+
}
485+
486+
public Builder presencePenalty(Double presencePenalty) {
487+
this.options.presencePenalty = presencePenalty;
488+
return this;
489+
}
490+
491+
public Builder N(Integer n) {
492+
this.options.n = n;
493+
return this;
494+
}
495+
441496
public Builder temperature(Double temperature) {
442497
this.options.setTemperature(temperature);
443498
return this;

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ void observationForChatOperation() {
7474
.stop(List.of("this-is-the-end"))
7575
.temperature(0.7)
7676
.topP(1.0)
77+
.presencePenalty(0.0)
78+
.frequencyPenalty(0.0)
79+
.N(2)
7780
.build();
7881

7982
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -95,6 +98,9 @@ void observationForStreamingChatOperation() {
9598
.stop(List.of("this-is-the-end"))
9699
.temperature(0.7)
97100
.topP(1.0)
101+
.presencePenalty(0.0)
102+
.frequencyPenalty(0.0)
103+
.N(2)
98104
.build();
99105

100106
Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
@@ -133,9 +139,9 @@ private void validate(ChatResponseMetadata responseMetadata) {
133139
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
134140
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
135141
: KeyValue.NONE_VALUE)
136-
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString())
142+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
143+
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
137144
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
138-
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString())
139145
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
140146
"[\"this-is-the-end\"]")
141147
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")

0 commit comments

Comments
 (0)