30
30
import org .springframework .ai .chat .model .ChatResponse ;
31
31
import org .springframework .ai .chat .prompt .Prompt ;
32
32
import org .springframework .ai .openai .OpenAiChatModel ;
33
+ import org .springframework .ai .openai .OpenAiChatOptions ;
33
34
import org .springframework .ai .openai .api .OpenAiApi ;
34
35
import org .springframework .ai .openai .metadata .support .OpenAiApiResponseHeaders ;
35
36
import org .springframework .beans .factory .annotation .Autowired ;
@@ -73,7 +74,7 @@ void resetMockServer() {
73
74
@ Test
74
75
void aiResponseContainsAiMetadata () {
75
76
76
- prepareMock ();
77
+ prepareMock (false );
77
78
78
79
Prompt prompt = new Prompt ("Reach for the sky." );
79
80
@@ -118,13 +119,32 @@ void aiResponseContainsAiMetadata() {
118
119
119
120
response .getResults ().forEach (generation -> {
120
121
ChatGenerationMetadata chatGenerationMetadata = generation .getMetadata ();
122
+ var logprobs = chatGenerationMetadata .get ("logprobs" );
123
+ assertThat (logprobs ).isNull ();
121
124
assertThat (chatGenerationMetadata ).isNotNull ();
122
125
assertThat (chatGenerationMetadata .getFinishReason ()).isEqualTo ("STOP" );
123
126
assertThat (chatGenerationMetadata .getContentFilters ()).isEmpty ();
124
127
});
125
128
}
126
129
127
- private void prepareMock () {
130
+ @ Test
131
+ void aiResponseContainsAiLogprobsMetadata () {
132
+
133
+ prepareMock (true );
134
+
135
+ Prompt prompt = new Prompt ("Reach for the sky." , new OpenAiChatOptions .Builder ().logprobs (true ).build ());
136
+
137
+ ChatResponse response = this .openAiChatClient .call (prompt );
138
+
139
+ assertThat (response ).isNotNull ();
140
+ assertThat (response .getResult ()).isNotNull ();
141
+ assertThat (response .getResult ().getMetadata ()).isNotNull ();
142
+
143
+ var logprobs = response .getResult ().getMetadata ().get ("logprobs" );
144
+ assertThat (logprobs ).isNotNull ().isInstanceOf (OpenAiApi .LogProbs .class );
145
+ }
146
+
147
+ private void prepareMock (boolean includeLogprobs ) {
128
148
129
149
HttpHeaders httpHeaders = new HttpHeaders ();
130
150
httpHeaders .set (OpenAiApiResponseHeaders .REQUESTS_LIMIT_HEADER .getName (), "4000" );
@@ -137,34 +157,58 @@ private void prepareMock() {
137
157
this .server .expect (requestTo (StringContains .containsString ("/v1/chat/completions" )))
138
158
.andExpect (method (HttpMethod .POST ))
139
159
.andExpect (header (HttpHeaders .AUTHORIZATION , "Bearer " + TEST_API_KEY ))
140
- .andRespond (withSuccess (getJson (), MediaType .APPLICATION_JSON ).headers (httpHeaders ));
160
+ .andRespond (withSuccess (getJson (includeLogprobs ), MediaType .APPLICATION_JSON ).headers (httpHeaders ));
141
161
142
162
}
143
163
144
- private String getJson () {
164
+ private String getBaseJson () {
145
165
return """
146
- {
147
- "id": "chatcmpl-123",
148
- "object": "chat.completion",
149
- "created": 1677652288,
150
- "model": "gpt-3.5-turbo-0613",
151
- "choices": [{
152
- "index": 0,
153
- "message": {
154
- "role": "assistant",
155
- "content": "I surrender!"
156
- },
157
- "finish_reason": "stop"
158
- }],
159
- "usage": {
160
- "prompt_tokens": 9,
161
- "completion_tokens": 12,
162
- "total_tokens": 21
163
- }
164
- }
166
+ {
167
+ "id": "chatcmpl-123",
168
+ "object": "chat.completion",
169
+ "created": 1677652288,
170
+ "model": "gpt-3.5-turbo-0613",
171
+ "choices": [{
172
+ "index": 0,
173
+ "message": {
174
+ "role": "assistant",
175
+ "content": "I surrender!"
176
+ },
177
+ %s
178
+ "finish_reason": "stop"
179
+ }],
180
+ "usage": {
181
+ "prompt_tokens": 9,
182
+ "completion_tokens": 12,
183
+ "total_tokens": 21
184
+ }
185
+ }
165
186
""" ;
166
187
}
167
188
189
+ private String getJson (boolean includeLogprobs ) {
190
+ if (includeLogprobs ) {
191
+ String logprobs = """
192
+ "logprobs" : {
193
+ "content" : [ {
194
+ "token" : "I",
195
+ "logprob" : -0.029507114,
196
+ "bytes" : [ 73 ],
197
+ "top_logprobs" : [ ]
198
+ }, {
199
+ "token" : " surrender!",
200
+ "logprob" : -0.061970375,
201
+ "bytes" : [ 32, 115, 117, 114, 114, 101, 110, 100, 101, 114, 33 ],
202
+ "top_logprobs" : [ ]
203
+ } ]
204
+ },
205
+ """ ;
206
+ return String .format (getBaseJson (), logprobs );
207
+ }
208
+
209
+ return String .format (getBaseJson (), "" );
210
+ }
211
+
168
212
@ SpringBootConfiguration
169
213
static class Config {
170
214
0 commit comments