Skip to content

Commit 6c9b4d6

Browse files
apappascsilayaperumalg
authored andcommitted
feat: OpenAI Web Search Annotations
This PR adds support for retrieving web search annotations from the OpenAI API, as described in their [web search documentation](https://platform.openai.com/docs/guides/web-search). This allows us to access citation URLs and their context within generated responses when using models like `gpt-4o-search-preview`. **Changes:** * Added `annotations` (with `Annotation` and `UrlCitation` records) to `ChatCompletionMessage` in `OpenAiApi.java`. * Updated `OpenAiChatModel` to populate the `annotations` field (via metadata) for both regular and streaming responses. * Added integration tests (`webSearchAnnotationsTest`, `streamWebSearchAnnotationsTest`) to `OpenAiChatModelIT.java`. * Added `GPT_4_O_SEARCH_PREVIEW` and `GPT_4_O_MINI_SEARCH_PREVIEW` to `OpenAiApi.ChatModel`. * Added `WebSearchOptions` and related records to `OpenAiApi`. * Minor updates to `ChatCompletionRequest` and its `Builder`. Resolves spring-projects#2449 Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 2c9214b commit 6c9b4d6

File tree

7 files changed

+217
-27
lines changed

7 files changed

+217
-27
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
217217
"role", choice.message().role() != null ? choice.message().role().name() : "",
218218
"index", choice.index(),
219219
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
220-
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
220+
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
221+
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
221222
return buildGeneration(choice, metadata, request);
222223
}).toList();
223224
// @formatter:on
@@ -315,8 +316,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
315316
"role", roleMap.getOrDefault(id, ""),
316317
"index", choice.index(),
317318
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
318-
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
319-
319+
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
320+
"annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of());
320321
return buildGeneration(choice, metadata, request);
321322
}).toList();
322323
// @formatter:on
@@ -583,7 +584,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
583584

584585
}
585586
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
586-
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
587+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
587588
}
588589
else if (message.getMessageType() == MessageType.TOOL) {
589590
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -593,7 +594,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
593594
return toolMessage.getResponses()
594595
.stream()
595596
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
596-
tr.id(), null, null, null))
597+
tr.id(), null, null, null, null))
597598
.toList();
598599
}
599600
else {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
3737
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
3838
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
39+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.WebSearchOptions;
3940
import org.springframework.ai.openai.api.ResponseFormat;
4041
import org.springframework.ai.tool.ToolCallback;
4142
import org.springframework.lang.Nullable;
@@ -194,6 +195,11 @@ public class OpenAiChatOptions implements ToolCallingChatOptions {
194195
*/
195196
private @JsonProperty("reasoning_effort") String reasoningEffort;
196197

198+
/**
199+
* This tool searches the web for relevant results to use in a response.
200+
*/
201+
private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions;
202+
197203
/**
198204
* Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests.
199205
*/
@@ -548,6 +554,14 @@ public void setReasoningEffort(String reasoningEffort) {
548554
this.reasoningEffort = reasoningEffort;
549555
}
550556

557+
public WebSearchOptions getWebSearchOptions() {
558+
return this.webSearchOptions;
559+
}
560+
561+
public void setWebSearchOptions(WebSearchOptions webSearchOptions) {
562+
this.webSearchOptions = webSearchOptions;
563+
}
564+
551565
@Override
552566
public OpenAiChatOptions copy() {
553567
return OpenAiChatOptions.fromOptions(this);
@@ -560,7 +574,7 @@ public int hashCode() {
560574
this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
561575
this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders,
562576
this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio,
563-
this.store, this.metadata, this.reasoningEffort);
577+
this.store, this.metadata, this.reasoningEffort, this.webSearchOptions);
564578
}
565579

566580
@Override
@@ -592,7 +606,8 @@ public boolean equals(Object o) {
592606
&& Objects.equals(this.outputModalities, other.outputModalities)
593607
&& Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store)
594608
&& Objects.equals(this.metadata, other.metadata)
595-
&& Objects.equals(this.reasoningEffort, other.reasoningEffort);
609+
&& Objects.equals(this.reasoningEffort, other.reasoningEffort)
610+
&& Objects.equals(this.webSearchOptions, other.webSearchOptions);
596611
}
597612

598613
@Override
@@ -780,6 +795,11 @@ public Builder reasoningEffort(String reasoningEffort) {
780795
return this;
781796
}
782797

798+
public Builder webSearchOptions(WebSearchOptions webSearchOptions) {
799+
this.options.webSearchOptions = webSearchOptions;
800+
return this;
801+
}
802+
783803
public OpenAiChatOptions build() {
784804
return this.options;
785805
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 116 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,21 @@ public enum ChatModel implements ChatModelDescription {
590590
* Context window: 4,096 tokens. Max output tokens: 4,096 tokens. Knowledge
591591
* cutoff: September, 2021.
592592
*/
593-
GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct");
593+
GPT_3_5_TURBO_INSTRUCT("gpt-3.5-turbo-instruct"),
594+
595+
/**
596+
* <b>GPT-4o Search Preview</b> is a specialized model for web search in Chat
597+
* Completions. It is trained to understand and execute web search queries. See
598+
* the web search guide for more information.
599+
*/
600+
GPT_4_O_SEARCH_PREVIEW("gpt-4o-search-preview"),
601+
602+
/**
603+
* <b>GPT-4o mini Search Preview</b> is a specialized model for web search in Chat
604+
* Completions. It is trained to understand and execute web search queries. See
605+
* the web search guide for more information.
606+
*/
607+
GPT_4_O_MINI_SEARCH_PREVIEW("gpt-4o-mini-search-preview");
594608

595609
public final String value;
596610

@@ -951,6 +965,10 @@ public enum OutputModality {
951965
* @param parallelToolCalls If set to true, the model will call all functions in the
952966
* tools list in parallel. Otherwise, the model will call the functions in the tools
953967
* list in the order they are provided.
968+
* @param reasoningEffort Constrains effort on reasoning for reasoning models.
969+
* Currently supported values are low, medium, and high. Reducing reasoning effort can
970+
* result in faster responses and fewer tokens used on reasoning in a response.
971+
* @param webSearchOptions Options for web search.
954972
*/
955973
@JsonInclude(Include.NON_NULL)
956974
public record ChatCompletionRequest(// @formatter:off
@@ -980,7 +998,8 @@ public record ChatCompletionRequest(// @formatter:off
980998
@JsonProperty("tool_choice") Object toolChoice,
981999
@JsonProperty("parallel_tool_calls") Boolean parallelToolCalls,
9821000
@JsonProperty("user") String user,
983-
@JsonProperty("reasoning_effort") String reasoningEffort) {
1001+
@JsonProperty("reasoning_effort") String reasoningEffort,
1002+
@JsonProperty("web_search_options") WebSearchOptions webSearchOptions) {
9841003

9851004
/**
9861005
* Shortcut constructor for a chat completion request with the given messages, model and temperature.
@@ -992,7 +1011,7 @@ public record ChatCompletionRequest(// @formatter:off
9921011
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
9931012
this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null,
9941013
null, null, null, false, null, temperature, null,
995-
null, null, null, null, null);
1014+
null, null, null, null, null, null);
9961015
}
9971016

9981017
/**
@@ -1006,7 +1025,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
10061025
this(messages, model, null, null, null, null, null, null,
10071026
null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null,
10081027
null, null, null, stream, null, null, null,
1009-
null, null, null, null, null);
1028+
null, null, null, null, null, null);
10101029
}
10111030

10121031
/**
@@ -1021,7 +1040,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
10211040
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
10221041
this(messages, model, null, null, null, null, null, null, null, null, null,
10231042
null, null, null, null, null, null, null, stream, null, temperature, null,
1024-
null, null, null, null, null);
1043+
null, null, null, null, null, null);
10251044
}
10261045

10271046
/**
@@ -1037,7 +1056,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
10371056
List<FunctionTool> tools, Object toolChoice) {
10381057
this(messages, model, null, null, null, null, null, null, null, null, null,
10391058
null, null, null, null, null, null, null, false, null, 0.8, null,
1040-
tools, toolChoice, null, null, null);
1059+
tools, toolChoice, null, null, null, null);
10411060
}
10421061

10431062
/**
@@ -1050,7 +1069,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
10501069
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
10511070
this(messages, null, null, null, null, null, null, null, null, null, null,
10521071
null, null, null, null, null, null, null, stream, null, null, null,
1053-
null, null, null, null, null);
1072+
null, null, null, null, null, null);
10541073
}
10551074

10561075
/**
@@ -1063,7 +1082,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) {
10631082
return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs,
10641083
this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty,
10651084
this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP,
1066-
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort);
1085+
this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions);
10671086
}
10681087

10691088
/**
@@ -1145,6 +1164,61 @@ public record StreamOptions(
11451164

11461165
public static StreamOptions INCLUDE_USAGE = new StreamOptions(true);
11471166
}
1167+
1168+
/**
1169+
* This tool searches the web for relevant results to use in a response.
1170+
*
1171+
* @param searchContextSize
1172+
* @param userLocation
1173+
*/
1174+
@JsonInclude(Include.NON_NULL)
1175+
public record WebSearchOptions(@JsonProperty("search_context_size") SearchContextSize searchContextSize,
1176+
@JsonProperty("user_location") UserLocation userLocation) {
1177+
1178+
/**
1179+
* High level guidance for the amount of context window space to use for the
1180+
* search. One of low, medium, or high. medium is the default.
1181+
*/
1182+
public enum SearchContextSize {
1183+
1184+
/**
1185+
* Low context size.
1186+
*/
1187+
@JsonProperty("low")
1188+
LOW,
1189+
1190+
/**
1191+
* Medium context size. This is the default.
1192+
*/
1193+
@JsonProperty("medium")
1194+
MEDIUM,
1195+
1196+
/**
1197+
* High context size.
1198+
*/
1199+
@JsonProperty("high")
1200+
HIGH
1201+
1202+
}
1203+
1204+
/**
1205+
* Approximate location parameters for the search.
1206+
*
1207+
* @param type The type of location approximation. Always "approximate".
1208+
* @param approximate The approximate location details.
1209+
*/
1210+
@JsonInclude(Include.NON_NULL)
1211+
public record UserLocation(@JsonProperty("type") String type,
1212+
@JsonProperty("approximate") Approximate approximate) {
1213+
1214+
@JsonInclude(Include.NON_NULL)
1215+
public record Approximate(@JsonProperty("city") String city, @JsonProperty("country") String country,
1216+
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
1217+
}
1218+
}
1219+
1220+
}
1221+
11481222
} // @formatter:on
11491223

11501224
/**
@@ -1163,19 +1237,22 @@ public record StreamOptions(
11631237
* Applicable only for {@link Role#ASSISTANT} role and null otherwise.
11641238
* @param refusal The refusal message by the assistant. Applicable only for
11651239
* {@link Role#ASSISTANT} role and null otherwise.
1166-
* @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
1167-
* Support audio input modality)
1240+
* @param audioOutput Audio response from the model.
1241+
* @param annotations Annotations for the message, when applicable, as when using the
1242+
* web search tool.
11681243
*/
1169-
@JsonInclude(Include.NON_NULL)
1170-
public record ChatCompletionMessage(// @formatter:off
1244+
@JsonInclude(JsonInclude.Include.NON_NULL)
1245+
public record ChatCompletionMessage(
1246+
// @formatter:off
11711247
@JsonProperty("content") Object rawContent,
11721248
@JsonProperty("role") Role role,
11731249
@JsonProperty("name") String name,
11741250
@JsonProperty("tool_call_id") String toolCallId,
1175-
@JsonProperty("tool_calls")
1176-
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
1251+
@JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
11771252
@JsonProperty("refusal") String refusal,
1178-
@JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
1253+
@JsonProperty("audio") AudioOutput audioOutput,
1254+
@JsonProperty("annotations") List<Annotation> annotations
1255+
) { // @formatter:on
11791256

11801257
/**
11811258
* Create a chat completion message with the given content and role. All other
@@ -1184,8 +1261,7 @@ public record ChatCompletionMessage(// @formatter:off
11841261
* @param role The role of the author of this message.
11851262
*/
11861263
public ChatCompletionMessage(Object content, Role role) {
1187-
this(content, role, null, null, null, null, null);
1188-
1264+
this(content, role, null, null, null, null, null, null);
11891265
}
11901266

11911267
/**
@@ -1362,6 +1438,29 @@ public record AudioOutput(// @formatter:off
13621438
@JsonProperty("transcript") String transcript
13631439
) { // @formatter:on
13641440
}
1441+
1442+
/**
1443+
* Represents an annotation within a message, specifically for URL citations.
1444+
*/
1445+
@JsonInclude(JsonInclude.Include.NON_NULL)
1446+
public record Annotation(@JsonProperty("type") String type,
1447+
@JsonProperty("url_citation") UrlCitation urlCitation) {
1448+
/**
1449+
* A URL citation when using web search.
1450+
*
1451+
* @param endIndex The index of the last character of the URL citation in the
1452+
* message.
1453+
* @param startIndex The index of the first character of the URL citation in
1454+
* the message.
1455+
* @param title The title of the web resource.
1456+
* @param url The URL of the web resource.
1457+
*/
1458+
@JsonInclude(JsonInclude.Include.NON_NULL)
1459+
public record UrlCitation(@JsonProperty("end_index") Integer endIndex,
1460+
@JsonProperty("start_index") Integer startIndex, @JsonProperty("title") String title,
1461+
@JsonProperty("url") String url) {
1462+
}
1463+
}
13651464
}
13661465

13671466
/**

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
*
4141
* @author Christian Tzolov
4242
* @author Thomas Vitale
43+
* @author Alexandros Pappas
4344
* @since 0.8.1
4445
*/
4546
public class OpenAiStreamFunctionCallingHelper {
@@ -98,6 +99,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
9899
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
99100
ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
100101
: previous.audioOutput());
102+
List<ChatCompletionMessage.Annotation> annotations = (current.annotations() != null ? current.annotations()
103+
: previous.annotations());
101104

102105
List<ToolCall> toolCalls = new ArrayList<>();
103106
ToolCall lastPreviousTooCall = null;
@@ -127,7 +130,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
127130
toolCalls.add(lastPreviousTooCall);
128131
}
129132
}
130-
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
133+
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations);
131134
}
132135

133136
private ToolCall merge(ToolCall previous, ToolCall current) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ void validateReasoningTokens() {
7575
"If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER);
7676
ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null,
7777
null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null,
78-
null, null, null, "low");
78+
null, null, null, "low", null);
7979
ResponseEntity<ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
8080

8181
assertThat(response).isNotNull();

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
*
4545
* @author Christian Tzolov
4646
* @author Thomas Vitale
47+
* @author Alexandros Pappas
4748
*/
4849
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
4950
public class OpenAiApiToolFunctionCallIT {
@@ -129,7 +130,7 @@ public void toolFunctionCall() {
129130

130131
// extend conversation with function response.
131132
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
132-
functionName, toolCall.id(), null, null, null));
133+
functionName, toolCall.id(), null, null, null, null));
133134
}
134135
}
135136

0 commit comments

Comments
 (0)