Skip to content

Commit 8acf4b9

Browse files
ndoe authored and markpollack committed
GH-2737 Returning logprobs in generation metadata when requested
Signed-off-by: ndoe <[email protected]>
1 parent fef01ed commit 8acf4b9

File tree

2 files changed

+71
-23
lines changed

2 files changed

+71
-23
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -432,6 +432,10 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
432432
generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
433433
}
434434

435+
if (Boolean.TRUE.equals(request.logprobs())) {
436+
generationMetadataBuilder.metadata("logprobs", choice.logprobs());
437+
}
438+
435439
var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
436440
return new Generation(assistantMessage, generationMetadataBuilder.build());
437441
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java

Lines changed: 67 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
import org.springframework.ai.chat.model.ChatResponse;
3131
import org.springframework.ai.chat.prompt.Prompt;
3232
import org.springframework.ai.openai.OpenAiChatModel;
33+
import org.springframework.ai.openai.OpenAiChatOptions;
3334
import org.springframework.ai.openai.api.OpenAiApi;
3435
import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders;
3536
import org.springframework.beans.factory.annotation.Autowired;
@@ -73,7 +74,7 @@ void resetMockServer() {
7374
@Test
7475
void aiResponseContainsAiMetadata() {
7576

76-
prepareMock();
77+
prepareMock(false);
7778

7879
Prompt prompt = new Prompt("Reach for the sky.");
7980

@@ -118,13 +119,32 @@ void aiResponseContainsAiMetadata() {
118119

119120
response.getResults().forEach(generation -> {
120121
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
122+
var logprobs = chatGenerationMetadata.get("logprobs");
123+
assertThat(logprobs).isNull();
121124
assertThat(chatGenerationMetadata).isNotNull();
122125
assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP");
123126
assertThat(chatGenerationMetadata.getContentFilters()).isEmpty();
124127
});
125128
}
126129

127-
private void prepareMock() {
130+
@Test
131+
void aiResponseContainsAiLogprobsMetadata() {
132+
133+
prepareMock(true);
134+
135+
Prompt prompt = new Prompt("Reach for the sky.", new OpenAiChatOptions.Builder().logprobs(true).build());
136+
137+
ChatResponse response = this.openAiChatClient.call(prompt);
138+
139+
assertThat(response).isNotNull();
140+
assertThat(response.getResult()).isNotNull();
141+
assertThat(response.getResult().getMetadata()).isNotNull();
142+
143+
var logprobs = response.getResult().getMetadata().get("logprobs");
144+
assertThat(logprobs).isNotNull().isInstanceOf(OpenAiApi.LogProbs.class);
145+
}
146+
147+
private void prepareMock(boolean includeLogprobs) {
128148

129149
HttpHeaders httpHeaders = new HttpHeaders();
130150
httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000");
@@ -137,34 +157,58 @@ private void prepareMock() {
137157
this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions")))
138158
.andExpect(method(HttpMethod.POST))
139159
.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY))
140-
.andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders));
160+
.andRespond(withSuccess(getJson(includeLogprobs), MediaType.APPLICATION_JSON).headers(httpHeaders));
141161

142162
}
143163

144-
private String getJson() {
164+
private String getBaseJson() {
145165
return """
146-
{
147-
"id": "chatcmpl-123",
148-
"object": "chat.completion",
149-
"created": 1677652288,
150-
"model": "gpt-3.5-turbo-0613",
151-
"choices": [{
152-
"index": 0,
153-
"message": {
154-
"role": "assistant",
155-
"content": "I surrender!"
156-
},
157-
"finish_reason": "stop"
158-
}],
159-
"usage": {
160-
"prompt_tokens": 9,
161-
"completion_tokens": 12,
162-
"total_tokens": 21
163-
}
164-
}
166+
{
167+
"id": "chatcmpl-123",
168+
"object": "chat.completion",
169+
"created": 1677652288,
170+
"model": "gpt-3.5-turbo-0613",
171+
"choices": [{
172+
"index": 0,
173+
"message": {
174+
"role": "assistant",
175+
"content": "I surrender!"
176+
},
177+
%s
178+
"finish_reason": "stop"
179+
}],
180+
"usage": {
181+
"prompt_tokens": 9,
182+
"completion_tokens": 12,
183+
"total_tokens": 21
184+
}
185+
}
165186
""";
166187
}
167188

189+
private String getJson(boolean includeLogprobs) {
190+
if (includeLogprobs) {
191+
String logprobs = """
192+
"logprobs" : {
193+
"content" : [ {
194+
"token" : "I",
195+
"logprob" : -0.029507114,
196+
"bytes" : [ 73 ],
197+
"top_logprobs" : [ ]
198+
}, {
199+
"token" : " surrender!",
200+
"logprob" : -0.061970375,
201+
"bytes" : [ 32, 115, 117, 114, 114, 101, 110, 100, 101, 114, 33 ],
202+
"top_logprobs" : [ ]
203+
} ]
204+
},
205+
""";
206+
return String.format(getBaseJson(), logprobs);
207+
}
208+
209+
return String.format(getBaseJson(), "");
210+
}
211+
168212
@SpringBootConfiguration
169213
static class Config {
170214

0 commit comments

Comments (0)