21 |  21 |  import com.azure.ai.openai.OpenAIClientBuilder;
22 |  22 |  import com.azure.ai.openai.models.*;
23 |  23 |  import com.azure.core.util.BinaryData;
   |  24 | +import io.micrometer.observation.Observation;
24 |  25 |  import io.micrometer.observation.ObservationRegistry;
25 |  26 |  import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
26 |  27 |  import org.springframework.ai.chat.messages.AssistantMessage;

37 |  38 |  import org.springframework.ai.chat.model.ChatModel;
38 |  39 |  import org.springframework.ai.chat.model.ChatResponse;
39 |  40 |  import org.springframework.ai.chat.model.Generation;
   |  41 | +import org.springframework.ai.chat.model.MessageAggregator;
40 |  42 |  import org.springframework.ai.chat.observation.ChatModelObservationContext;
41 |  43 |  import org.springframework.ai.chat.observation.ChatModelObservationConvention;
42 |  44 |  import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;

51 |  53 |  import org.springframework.ai.observation.conventions.AiProvider;
52 |  54 |  import org.springframework.util.Assert;
53 |  55 |  import org.springframework.util.CollectionUtils;
   |  56 | +
   |  57 | +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
54 |  58 |  import reactor.core.publisher.Flux;
55 |     | -import reactor.core.publisher.Mono;
56 |  59 |  
57 |  60 |  import java.util.ArrayList;
58 |  61 |  import java.util.Base64;

62 |  65 |  import java.util.Map;
63 |  66 |  import java.util.Optional;
64 |  67 |  import java.util.Set;
   |  68 | +import java.util.concurrent.ConcurrentHashMap;
65 |  69 |  import java.util.concurrent.atomic.AtomicBoolean;
66 |  70 |  
67 |  71 |  /**
@@ -189,51 +193,83 @@ && isToolCall(response, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS

189 | 193 |  @Override
190 | 194 |  public Flux<ChatResponse> stream(Prompt prompt) {
191 | 195 |
192 |     | -    ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
193 |     | -    options.setStream(true);
194 |     | -
195 |     | -    Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
196 |     | -        .getChatCompletionsStream(options.getModel(), options);
197 |     | -
198 |     | -    final var isFunctionCall = new AtomicBoolean(false);
199 |     | -    final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
200 |     | -        // Note: the first chat completions can be ignored when using Azure OpenAI
201 |     | -        // service which is a known service bug.
202 |     | -        .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
203 |     | -        .map(chatCompletions -> {
204 |     | -            final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
205 |     | -            isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
206 |     | -            return chatCompletions;
207 |     | -        })
208 |     | -        .windowUntil(chatCompletions -> {
209 |     | -            if (isFunctionCall.get() && chatCompletions.getChoices()
210 |     | -                .get(0)
211 |     | -                .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
212 |     | -                isFunctionCall.set(false);
213 |     | -                return true;
    | 196 | +    return Flux.deferContextual(contextView -> {
    | 197 | +        ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
    | 198 | +        options.setStream(true);
    | 199 | +
    | 200 | +        Flux<ChatCompletions> chatCompletionsStream = this.openAIAsyncClient
    | 201 | +            .getChatCompletionsStream(options.getModel(), options);
    | 202 | +
    | 203 | +        // For chunked responses, only the first chunk contains the choice role.
    | 204 | +        // The rest of the chunks with same ID share the same role.
    | 205 | +        ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
    | 206 | +
    | 207 | +        ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
    | 208 | +            .prompt(prompt)
    | 209 | +            .provider(AiProvider.AZURE_OPENAI.value())
    | 210 | +            .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
    | 211 | +            .build();
    | 212 | +
    | 213 | +        Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
    | 214 | +                this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
    | 215 | +                this.observationRegistry);
    | 216 | +
    | 217 | +        observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
    | 218 | +
    | 219 | +        final var isFunctionCall = new AtomicBoolean(false);
    | 220 | +
    | 221 | +        final Flux<ChatCompletions> accessibleChatCompletionsFlux = chatCompletionsStream
    | 222 | +            // Note: the first chat completions can be ignored when using Azure OpenAI
    | 223 | +            // service which is a known service bug.
    | 224 | +            .filter(chatCompletions -> !CollectionUtils.isEmpty(chatCompletions.getChoices()))
    | 225 | +            .map(chatCompletions -> {
    | 226 | +                final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
    | 227 | +                isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
    | 228 | +                return chatCompletions;
    | 229 | +            })
    | 230 | +            .windowUntil(chatCompletions -> {
    | 231 | +                if (isFunctionCall.get() && chatCompletions.getChoices()
    | 232 | +                    .get(0)
    | 233 | +                    .getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
    | 234 | +                    isFunctionCall.set(false);
    | 235 | +                    return true;
    | 236 | +                }
    | 237 | +                return !isFunctionCall.get();
    | 238 | +            })
    | 239 | +            .concatMapIterable(window -> {
    | 240 | +                final var reduce = window.reduce(MergeUtils.emptyChatCompletions(),
    | 241 | +                        MergeUtils::mergeChatCompletions);
    | 242 | +                return List.of(reduce);
    | 243 | +            })
    | 244 | +            .flatMap(mono -> mono);
    | 245 | +
    | 246 | +        return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {
    | 247 | +
    | 248 | +            ChatResponse chatResponse = toChatResponse(chatCompletions);
    | 249 | +
    | 250 | +            if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
    | 251 | +                    Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
    | 252 | +                var toolCallConversation = handleToolCalls(prompt, chatResponse);
    | 253 | +                // Recursively call the call method with the tool call message
    | 254 | +                // conversation that contains the call responses.
    | 255 | +                return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
214 | 256 |              }
215 |     | -            return !isFunctionCall.get();
216 |     | -        })
217 |     | -        .concatMapIterable(window -> {
218 |     | -            final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
219 |     | -            return List.of(reduce);
220 |     | -        })
221 |     | -        .flatMap(mono -> mono);
222 |     | -
223 |     | -    return accessibleChatCompletionsFlux.switchMap(chatCompletions -> {
224 |     | -
225 |     | -        ChatResponse chatResponse = toChatResponse(chatCompletions);
226 |     | -
227 |     | -        if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
228 |     | -                Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
229 |     | -            var toolCallConversation = handleToolCalls(prompt, chatResponse);
230 |     | -            // Recursively call the call method with the tool call message
231 |     | -            // conversation that contains the call responses.
232 |     | -            return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
233 |     | -        }
234 | 257 |
235 |     | -        return Mono.just(chatResponse);
    | 258 | +            Flux<ChatResponse> flux = Flux.just(chatResponse).doOnError(observation::error).doFinally(s -> {
    | 259 | +                // TODO: Consider a custom ObservationContext and
    | 260 | +                // include additional metadata
    | 261 | +                // if (s == SignalType.CANCEL) {
    | 262 | +                // observationContext.setAborted(true);
    | 263 | +                // }
    | 264 | +                observation.stop();
    | 265 | +            }).contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
    | 266 | +            // @formatter:on
    | 267 | +
    | 268 | +            return new MessageAggregator().aggregate(flux, observationContext::setResponse);
    | 269 | +        });
    | 270 | +
236 | 271 |      });
    | 272 | +
237 | 273 |  }
238 | 274 |
239 | 275 |  private ChatResponse toChatResponse(ChatCompletions chatCompletions) {
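Taken together, stream(Prompt) now opens its own chat-model observation, uses the observation found in the Reactor context (if any) as its parent, stops it when the per-window flux terminates, and lets MessageAggregator.aggregate(..) set the aggregated ChatResponse on the observation context once streaming completes. A hypothetical caller-side sketch of linking an outer observation to the streaming one; chatModel, registry, and the prompt text are placeholders, and getContent() is assumed to return the chunk text in this Spring AI milestone:

Observation parent = Observation.start("my.request", registry);
try {
    chatModel.stream(new Prompt("Tell me a joke"))
        // Expose the outer observation through the subscriber context; stream() reads it
        // under ObservationThreadLocalAccessor.KEY and uses it as the parent observation.
        .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, parent))
        .doOnNext(response -> System.out.print(response.getResult().getOutput().getContent()))
        .blockLast();
}
finally {
    parent.stop();
}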