Skip to content

Commit 3a3439e

Browse files
Bart Veenstramarkpollack
authored andcommitted
feat: enhance AzureOpenAiResponseFormat to support JSON schema and builder pattern
Signed-off-by: Bart Veenstra <[email protected]>
1 parent 52675d8 commit 3a3439e

File tree

4 files changed

+267
-10
lines changed

4 files changed

+267
-10
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19+
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormat;
20+
import com.azure.ai.openai.models.ChatCompletionsJsonSchemaResponseFormatJsonSchema;
1921
import java.util.ArrayList;
2022
import java.util.Base64;
2123
import java.util.Collections;
@@ -59,6 +61,8 @@
5961
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
6062
import org.slf4j.Logger;
6163
import org.slf4j.LoggerFactory;
64+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
65+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
6266
import reactor.core.publisher.Flux;
6367
import reactor.core.scheduler.Schedulers;
6468

@@ -115,6 +119,7 @@
115119
* @author Alexandros Pappas
116120
* @author Berjan Jonker
117121
* @author Andres da Silva Santos
122+
* @author Bart Veenstra
118123
* @see ChatModel
119124
* @see com.azure.ai.openai.OpenAIClient
120125
* @since 1.0.0
@@ -918,9 +923,16 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
918923
* @return Azure response format
919924
*/
920925
private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseFormat responseFormat) {
921-
if (responseFormat == AzureOpenAiResponseFormat.JSON) {
926+
if (responseFormat.getType() == Type.JSON_OBJECT) {
922927
return new ChatCompletionsJsonResponseFormat();
923928
}
929+
if (responseFormat.getType() == Type.JSON_SCHEMA) {
930+
JsonSchema jsonSchema = responseFormat.getJsonSchema();
931+
var responseFormatJsonSchema = new ChatCompletionsJsonSchemaResponseFormatJsonSchema(jsonSchema.getName());
932+
String jsonString = ModelOptionsUtils.toJsonString(jsonSchema.getSchema());
933+
responseFormatJsonSchema.setSchema(BinaryData.fromString(jsonString));
934+
return new ChatCompletionsJsonSchemaResponseFormat(responseFormatJsonSchema);
935+
}
924936
return new ChatCompletionsTextResponseFormat();
925937
}
926938

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiResponseFormat.java

Lines changed: 242 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,23 @@
1616

1717
package org.springframework.ai.azure.openai;
1818

19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import java.util.Map;
23+
import java.util.Objects;
24+
import org.springframework.ai.model.ModelOptionsUtils;
25+
import org.springframework.util.StringUtils;
26+
1927
/**
2028
* Utility enumeration for representing the response format that may be requested from the
2129
* Azure OpenAI model. Please check <a href=
2230
* "https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format">OpenAI
2331
* API documentation</a> for more details.
2432
*/
25-
public enum AzureOpenAiResponseFormat {
33+
@JsonInclude(Include.NON_NULL)
34+
public class AzureOpenAiResponseFormat {
2635

27-
// default value used by OpenAI
28-
TEXT,
2936
/*
3037
* From the OpenAI API documentation: Compatibility: Compatible with GPT-4 Turbo and
3138
* all GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Caveats: This enables JSON
@@ -36,7 +43,238 @@ public enum AzureOpenAiResponseFormat {
3643
* long-running and seemingly "stuck" request. Also note that the message content may
3744
* be partially cut off if finish_reason="length", which indicates the generation
3845
* exceeded max_tokens or the conversation exceeded the max context length.
46+
*
47+
* Type Must be one of 'text', 'json_object' or 'json_schema'.
48+
*/
49+
@JsonProperty("type")
50+
private Type type;
51+
52+
/**
53+
* JSON schema object that describes the format of the JSON object. Only applicable
54+
* when type is 'json_schema'.
55+
*/
56+
@JsonProperty("json_schema")
57+
private JsonSchema jsonSchema = null;
58+
59+
private String schema;
60+
61+
public AzureOpenAiResponseFormat() {
62+
63+
}
64+
65+
public Type getType() {
66+
return this.type;
67+
}
68+
69+
public void setType(Type type) {
70+
this.type = type;
71+
}
72+
73+
public JsonSchema getJsonSchema() {
74+
return this.jsonSchema;
75+
}
76+
77+
public void setJsonSchema(JsonSchema jsonSchema) {
78+
this.jsonSchema = jsonSchema;
79+
}
80+
81+
public String getSchema() {
82+
return this.schema;
83+
}
84+
85+
public void setSchema(String schema) {
86+
this.schema = schema;
87+
if (schema != null) {
88+
this.jsonSchema = JsonSchema.builder().schema(schema).strict(true).build();
89+
}
90+
}
91+
92+
private AzureOpenAiResponseFormat(Type type, JsonSchema jsonSchema) {
93+
this.type = type;
94+
this.jsonSchema = jsonSchema;
95+
}
96+
97+
public AzureOpenAiResponseFormat(Type type, String schema) {
98+
this(type, StringUtils.hasText(schema) ? JsonSchema.builder().schema(schema).strict(true).build() : null);
99+
}
100+
101+
public static Builder builder() {
102+
return new Builder();
103+
}
104+
105+
@Override
106+
public boolean equals(Object o) {
107+
if (this == o) {
108+
return true;
109+
}
110+
if (o == null || getClass() != o.getClass()) {
111+
return false;
112+
}
113+
AzureOpenAiResponseFormat that = (AzureOpenAiResponseFormat) o;
114+
return this.type == that.type && Objects.equals(this.jsonSchema, that.jsonSchema);
115+
}
116+
117+
@Override
118+
public int hashCode() {
119+
return Objects.hash(this.type, this.jsonSchema);
120+
}
121+
122+
@Override
123+
public String toString() {
124+
return "ResponseFormat{" + "type=" + this.type + ", jsonSchema=" + this.jsonSchema + '}';
125+
}
126+
127+
public static final class Builder {
128+
129+
private Type type;
130+
131+
private JsonSchema jsonSchema;
132+
133+
private Builder() {
134+
}
135+
136+
public Builder type(Type type) {
137+
this.type = type;
138+
return this;
139+
}
140+
141+
public Builder jsonSchema(JsonSchema jsonSchema) {
142+
this.jsonSchema = jsonSchema;
143+
return this;
144+
}
145+
146+
public Builder jsonSchema(String jsonSchema) {
147+
this.jsonSchema = JsonSchema.builder().schema(jsonSchema).build();
148+
return this;
149+
}
150+
151+
public AzureOpenAiResponseFormat build() {
152+
return new AzureOpenAiResponseFormat(this.type, this.jsonSchema);
153+
}
154+
155+
}
156+
157+
public enum Type {
158+
159+
/**
160+
* Generates a text response. (default)
161+
*/
162+
@JsonProperty("text")
163+
TEXT,
164+
165+
/**
166+
* Enables JSON mode, which guarantees the message the model generates is valid
167+
* JSON.
168+
*/
169+
@JsonProperty("json_object")
170+
JSON_OBJECT,
171+
172+
/**
173+
* Enables Structured Outputs which guarantees the model will match your supplied
174+
* JSON schema.
175+
*/
176+
@JsonProperty("json_schema")
177+
JSON_SCHEMA
178+
179+
}
180+
181+
/**
182+
* JSON schema object that describes the format of the JSON object. Applicable for the
183+
* 'json_schema' type only.
39184
*/
40-
JSON
185+
@JsonInclude(Include.NON_NULL)
186+
public static class JsonSchema {
187+
188+
@JsonProperty("name")
189+
private String name;
190+
191+
@JsonProperty("schema")
192+
private Map<String, Object> schema;
193+
194+
@JsonProperty("strict")
195+
private Boolean strict;
196+
197+
public JsonSchema() {
198+
199+
}
200+
201+
public String getName() {
202+
return this.name;
203+
}
204+
205+
public Map<String, Object> getSchema() {
206+
return this.schema;
207+
}
208+
209+
public Boolean getStrict() {
210+
return this.strict;
211+
}
212+
213+
private JsonSchema(String name, Map<String, Object> schema, Boolean strict) {
214+
this.name = name;
215+
this.schema = schema;
216+
this.strict = strict;
217+
}
218+
219+
public static Builder builder() {
220+
return new Builder();
221+
}
222+
223+
@Override
224+
public int hashCode() {
225+
return Objects.hash(this.name, this.schema, this.strict);
226+
}
227+
228+
@Override
229+
public boolean equals(Object o) {
230+
if (this == o) {
231+
return true;
232+
}
233+
if (o == null || getClass() != o.getClass()) {
234+
return false;
235+
}
236+
JsonSchema that = (JsonSchema) o;
237+
return Objects.equals(this.name, that.name) && Objects.equals(this.schema, that.schema)
238+
&& Objects.equals(this.strict, that.strict);
239+
}
240+
241+
public static final class Builder {
242+
243+
private String name = "custom_schema";
244+
245+
private Map<String, Object> schema;
246+
247+
private Boolean strict = true;
248+
249+
private Builder() {
250+
}
251+
252+
public Builder name(String name) {
253+
this.name = name;
254+
return this;
255+
}
256+
257+
public Builder schema(Map<String, Object> schema) {
258+
this.schema = schema;
259+
return this;
260+
}
261+
262+
public Builder schema(String schema) {
263+
this.schema = ModelOptionsUtils.jsonToMap(schema);
264+
return this;
265+
}
266+
267+
public Builder strict(Boolean strict) {
268+
this.strict = strict;
269+
return this;
270+
}
271+
272+
public JsonSchema build() {
273+
return new JsonSchema(this.name, this.schema, this.strict);
274+
}
275+
276+
}
277+
278+
}
41279

42280
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.junit.jupiter.params.provider.MethodSource;
3131
import org.mockito.Mockito;
3232

33+
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
3334
import org.springframework.ai.chat.prompt.Prompt;
3435

3536
import static org.assertj.core.api.Assertions.assertThat;
@@ -68,7 +69,7 @@ public void createRequestWithChatOptions() {
6869
.logprobs(true)
6970
.topLogprobs(5)
7071
.enhancements(mockAzureChatEnhancementConfiguration)
71-
.responseFormat(AzureOpenAiResponseFormat.TEXT)
72+
.responseFormat(AzureOpenAiResponseFormat.builder().type(Type.TEXT).build())
7273
.build();
7374

7475
var client = AzureOpenAiChatModel.builder()
@@ -114,7 +115,7 @@ public void createRequestWithChatOptions() {
114115
.logprobs(true)
115116
.topLogprobs(4)
116117
.enhancements(anotherMockAzureChatEnhancementConfiguration)
117-
.responseFormat(AzureOpenAiResponseFormat.JSON)
118+
.responseFormat(AzureOpenAiResponseFormat.builder().type(Type.JSON_OBJECT).build())
118119
.build();
119120

120121
requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content", runtimeOptions));

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ class AzureOpenAiChatOptionsTests {
3636

3737
@Test
3838
void testBuilderWithAllFields() {
39-
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT;
39+
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.builder()
40+
.type(AzureOpenAiResponseFormat.Type.TEXT)
41+
.build();
4042
ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
4143
streamOptions.setIncludeUsage(true);
4244

@@ -75,7 +77,9 @@ void testBuilderWithAllFields() {
7577

7678
@Test
7779
void testCopy() {
78-
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT;
80+
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.builder()
81+
.type(AzureOpenAiResponseFormat.Type.TEXT)
82+
.build();
7983
ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
8084
streamOptions.setIncludeUsage(true);
8185

@@ -113,7 +117,9 @@ void testCopy() {
113117

114118
@Test
115119
void testSetters() {
116-
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT;
120+
AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.builder()
121+
.type(AzureOpenAiResponseFormat.Type.TEXT)
122+
.build();
117123
ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions();
118124
streamOptions.setIncludeUsage(true);
119125
AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration();

0 commit comments

Comments
 (0)