Skip to content

Add request body parameters to Mistral AI Chat Completion #2706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ public class MistralAiChatOptions implements ToolCallingChatOptions {
*/
private @JsonProperty("stop") List<String> stop;

/**
* Number between -2.0 and 2.0. frequency_penalty penalizes the repetition of words
* based on their frequency in the generated text. A higher frequency penalty
* discourages the model from repeating words that have already appeared frequently in
* the output, promoting diversity and reducing repetition.
*/
private @JsonProperty("frequency_penalty") Double frequencyPenalty;

/**
* Number between -2.0 and 2.0. presence_penalty determines how much the model
* penalizes the repetition of words or phrases. A higher presence penalty encourages
* the model to use a wider variety of words and phrases, making the output more
* diverse and creative.
*/
private @JsonProperty("presence_penalty") Double presencePenalty;

/**
* Number of completions to return for each request, input tokens are only billed
* once.
*/
private @JsonProperty("n") Integer n;

/**
* A list of tools the model may call. Currently, only functions are supported as a
* tool. Use this to provide a list of functions the model may generate JSON inputs
Expand Down Expand Up @@ -151,6 +173,9 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions)
.topP(fromOptions.getTopP())
.responseFormat(fromOptions.getResponseFormat())
.stop(fromOptions.getStop())
.frequencyPenalty(fromOptions.getFrequencyPenalty())
.presencePenalty(fromOptions.getPresencePenalty())
.n(fromOptions.getN())
.tools(fromOptions.getTools())
.toolChoice(fromOptions.getToolChoice())
.toolCallbacks(fromOptions.getToolCallbacks())
Expand Down Expand Up @@ -255,6 +280,32 @@ public void setTopP(Double topP) {
this.topP = topP;
}

@Override
public Double getFrequencyPenalty() {
return this.frequencyPenalty;
}

public void setFrequencyPenalty(Double frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}

@Override
public Double getPresencePenalty() {
return this.presencePenalty;
}

public void setPresencePenalty(Double presencePenalty) {
this.presencePenalty = presencePenalty;
}

public Integer getN() {
return this.n;
}

public void setN(Integer n) {
this.n = n;
}

@Override
@JsonIgnore
public List<FunctionCallback> getToolCallbacks() {
Expand Down Expand Up @@ -325,18 +376,6 @@ public void setFunctions(Set<String> functionNames) {
this.setToolNames(functionNames);
}

@Override
@JsonIgnore
public Double getFrequencyPenalty() {
return null;
}

@Override
@JsonIgnore
public Double getPresencePenalty() {
return null;
}

@Override
@JsonIgnore
public Integer getTopK() {
Expand Down Expand Up @@ -376,8 +415,8 @@ public MistralAiChatOptions copy() {
@Override
public int hashCode() {
return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed,
this.responseFormat, this.stop, this.tools, this.toolChoice, this.toolCallbacks, this.tools,
this.internalToolExecutionEnabled, this.toolContext);
this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools,
this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext);
}

@Override
Expand All @@ -397,6 +436,8 @@ public boolean equals(Object obj) {
&& Objects.equals(this.safePrompt, other.safePrompt)
&& Objects.equals(this.randomSeed, other.randomSeed)
&& Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.stop, other.stop)
&& Objects.equals(this.frequencyPenalty, other.frequencyPenalty)
&& Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.n, other.n)
&& Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice)
&& Objects.equals(this.toolCallbacks, other.toolCallbacks)
&& Objects.equals(this.toolNames, other.toolNames)
Expand Down Expand Up @@ -438,6 +479,21 @@ public Builder stop(List<String> stop) {
return this;
}

public Builder frequencyPenalty(Double frequencyPenalty) {
this.options.frequencyPenalty = frequencyPenalty;
return this;
}

public Builder presencePenalty(Double presencePenalty) {
this.options.presencePenalty = presencePenalty;
return this;
}

public Builder n(Integer n) {
this.options.n = n;
return this;
}

public Builder temperature(Double temperature) {
this.options.setTemperature(temperature);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ void observationForChatOperation() {
.stop(List.of("this-is-the-end"))
.temperature(0.7)
.topP(1.0)
.presencePenalty(0.0)
.frequencyPenalty(0.0)
.n(2)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
Expand All @@ -95,6 +98,9 @@ void observationForStreamingChatOperation() {
.stop(List.of("this-is-the-end"))
.temperature(0.7)
.topP(1.0)
.presencePenalty(0.0)
.frequencyPenalty(0.0)
.n(2)
.build();

Prompt prompt = new Prompt("Why does a raven look like a desk?", options);
Expand Down Expand Up @@ -133,9 +139,9 @@ private void validate(ChatResponseMetadata responseMetadata) {
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
: KeyValue.NONE_VALUE)
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), "0.0")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), "0.0")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), "2048")
.doesNotHaveHighCardinalityKeyValueWithKey(HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString())
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(),
"[\"this-is-the-end\"]")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), "0.7")
Expand Down