diff --git a/OpenAI.sln b/OpenAI.sln index d6350d85..d3c8d978 100644 --- a/OpenAI.sln +++ b/OpenAI.sln @@ -6,7 +6,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI", "src\OpenAI.csproj EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.Examples", "examples\OpenAI.Examples.csproj", "{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenAI.Tests", "tests\OpenAI.Tests.csproj", "{6F156401-2544-41D7-B204-3148C51C1D09}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.Tests", "tests\OpenAI.Tests.csproj", "{6F156401-2544-41D7-B204-3148C51C1D09}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.ClientModel", "..\azure-sdk-for-net\sdk\core\System.ClientModel\src\System.ClientModel.csproj", "{8316F2D5-21A7-468B-97DB-B14C44B4F50C}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -26,6 +28,10 @@ Global {6F156401-2544-41D7-B204-3148C51C1D09}.Debug|Any CPU.Build.0 = Debug|Any CPU {6F156401-2544-41D7-B204-3148C51C1D09}.Release|Any CPU.ActiveCfg = Release|Any CPU {6F156401-2544-41D7-B204-3148C51C1D09}.Release|Any CPU.Build.0 = Release|Any CPU + {8316F2D5-21A7-468B-97DB-B14C44B4F50C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8316F2D5-21A7-468B-97DB-B14C44B4F50C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8316F2D5-21A7-468B-97DB-B14C44B4F50C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8316F2D5-21A7-468B-97DB-B14C44B4F50C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -33,4 +39,4 @@ Global GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {A97F4B90-2591-4689-B1F8-5F21FE6D6CAE} EndGlobalSection -EndGlobal \ No newline at end of file +EndGlobal diff --git a/examples/Assistants/Example01_RetrievalAugmentedGeneration.cs b/examples/Assistants/Example01_RetrievalAugmentedGeneration.cs index 549ea3db..7ae465fb 100644 --- a/examples/Assistants/Example01_RetrievalAugmentedGeneration.cs +++ b/examples/Assistants/Example01_RetrievalAugmentedGeneration.cs @@ -5,7 +5,6 @@ using System.ClientModel; using System.Collections.Generic; using System.IO; -using System.Threading; namespace OpenAI.Examples; @@ -88,18 +87,12 @@ public void Example01_RetrievalAugmentedGeneration() InitialMessages = { "How well did product 113045 sell in February? Graph its trend over time." } }; - ThreadRun threadRun = assistantClient.CreateThreadAndRun(assistant.Id, threadOptions); - - // Check back to see when the run is done - do - { - Thread.Sleep(TimeSpan.FromSeconds(1)); - threadRun = assistantClient.GetRun(threadRun.ThreadId, threadRun.Id); - } while (!threadRun.Status.IsTerminal); + // Passing ReturnWhen.Completed means CreateThreadAndRun will return control after the run is complete. + RunOperation runOperation = assistantClient.CreateThreadAndRun(ReturnWhen.Completed, assistant.Id, threadOptions); // Finally, we'll print out the full history for the thread that includes the augmented generation PageCollection messagePages - = assistantClient.GetMessages(threadRun.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); + = assistantClient.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); IEnumerable messages = messagePages.GetAllValues(); foreach (ThreadMessage message in messages) @@ -143,7 +136,7 @@ PageCollection messagePages } // Optionally, delete any persistent resources you no longer need. - _ = assistantClient.DeleteThread(threadRun.ThreadId); + _ = assistantClient.DeleteThread(runOperation.ThreadId); _ = assistantClient.DeleteAssistant(assistant); _ = fileClient.DeleteFile(salesFile); } diff --git a/examples/Assistants/Example01_RetrievalAugmentedGenerationAsync.cs b/examples/Assistants/Example01_RetrievalAugmentedGenerationAsync.cs index 83b95fb0..e6afefac 100644 --- a/examples/Assistants/Example01_RetrievalAugmentedGenerationAsync.cs +++ b/examples/Assistants/Example01_RetrievalAugmentedGenerationAsync.cs @@ -5,7 +5,6 @@ using System.ClientModel; using System.Collections.Generic; using System.IO; -using System.Threading; using System.Threading.Tasks; namespace OpenAI.Examples; @@ -89,18 +88,11 @@ public async Task Example01_RetrievalAugmentedGenerationAsync() InitialMessages = { "How well did product 113045 sell in February? Graph its trend over time." } }; - ThreadRun threadRun = await assistantClient.CreateThreadAndRunAsync(assistant.Id, threadOptions); - - // Check back to see when the run is done - do - { - Thread.Sleep(TimeSpan.FromSeconds(1)); - threadRun = assistantClient.GetRun(threadRun.ThreadId, threadRun.Id); - } while (!threadRun.Status.IsTerminal); + RunOperation runOperation = await assistantClient.CreateThreadAndRunAsync(ReturnWhen.Completed, assistant.Id, threadOptions); // Finally, we'll print out the full history for the thread that includes the augmented generation AsyncPageCollection messagePages - = assistantClient.GetMessagesAsync(threadRun.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); + = assistantClient.GetMessagesAsync(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); IAsyncEnumerable messages = messagePages.GetAllValuesAsync(); await foreach (ThreadMessage message in messages) @@ -144,7 +136,7 @@ AsyncPageCollection messagePages } // Optionally, delete any persistent resources you no longer need. - _ = await assistantClient.DeleteThreadAsync(threadRun.ThreadId); + _ = await assistantClient.DeleteThreadAsync(runOperation.ThreadId); _ = await assistantClient.DeleteAssistantAsync(assistant); _ = await fileClient.DeleteFileAsync(salesFile); } diff --git a/examples/Assistants/Example02_FunctionCalling.cs b/examples/Assistants/Example02_FunctionCalling.cs index a0eb0563..abe753d2 100644 --- a/examples/Assistants/Example02_FunctionCalling.cs +++ b/examples/Assistants/Example02_FunctionCalling.cs @@ -13,7 +13,7 @@ public partial class AssistantExamples [Test] public void Example02_FunctionCalling() { - #region + #region Define Functions string GetCurrentLocation() { // Call a location API here. @@ -64,7 +64,7 @@ string GetCurrentWeather(string location, string unit = "celsius") #pragma warning disable OPENAI001 AssistantClient client = new(Environment.GetEnvironmentVariable("OPENAI_API_KEY")); - #region + #region Create Assistant // Create an assistant that can call the function tools. AssistantCreationOptions assistantOptions = new() { @@ -78,29 +78,27 @@ string GetCurrentWeather(string location, string unit = "celsius") Assistant assistant = client.CreateAssistant("gpt-4-turbo", assistantOptions); #endregion - #region + #region Create Thread and Run // Create a thread with an initial user message and run it. ThreadCreationOptions threadOptions = new() { InitialMessages = { "What's the weather like today?" } }; - ThreadRun run = client.CreateThreadAndRun(assistant.Id, threadOptions); + RunOperation runOperation = client.CreateThreadAndRun(ReturnWhen.Started, assistant.Id, threadOptions); #endregion - #region - // Poll the run until it is no longer queued or in progress. - while (!run.Status.IsTerminal) - { - Thread.Sleep(TimeSpan.FromSeconds(1)); - run = client.GetRun(run.ThreadId, run.Id); + #region Submit tool outputs to run + + IEnumerable updates = runOperation.GetUpdates(); - // If the run requires action, resolve them. - if (run.Status == RunStatus.RequiresAction) + foreach (ThreadRun update in updates) + { + if (update.Status == RunStatus.RequiresAction) { List toolOutputs = []; - foreach (RequiredAction action in run.RequiredActions) + foreach (RequiredAction action in runOperation.Value.RequiredActions) { switch (action.FunctionName) { @@ -142,17 +140,19 @@ string GetCurrentWeather(string location, string unit = "celsius") } // Submit the tool outputs to the assistant, which returns the run to the queued state. - run = client.SubmitToolOutputsToRun(run.ThreadId, run.Id, toolOutputs); + runOperation.SubmitToolOutputsToRun(toolOutputs); } } + #endregion - #region - // With the run complete, list the messages and display their content - if (run.Status == RunStatus.Completed) + #region Get and display messages + + // If the run completed successfully, list the messages and display their content + if (runOperation.Status == RunStatus.Completed) { PageCollection messagePages - = client.GetMessages(run.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); + = client.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); IEnumerable messages = messagePages.GetAllValues(); foreach (ThreadMessage message in messages) @@ -186,7 +186,7 @@ PageCollection messagePages } else { - throw new NotImplementedException(run.Status.ToString()); + throw new NotImplementedException(runOperation.Status.ToString()); } #endregion } diff --git a/examples/Assistants/Example02_FunctionCallingAsync.cs b/examples/Assistants/Example02_FunctionCallingAsync.cs index 2ee924c5..a828deb9 100644 --- a/examples/Assistants/Example02_FunctionCallingAsync.cs +++ b/examples/Assistants/Example02_FunctionCallingAsync.cs @@ -85,22 +85,22 @@ string GetCurrentWeather(string location, string unit = "celsius") InitialMessages = { "What's the weather like today?" } }; - ThreadRun run = await client.CreateThreadAndRunAsync(assistant.Id, threadOptions); + RunOperation runOperation = await client.CreateThreadAndRunAsync(ReturnWhen.Started, assistant.Id, threadOptions); #endregion #region - // Poll the run until it is no longer queued or in progress. - while (!run.Status.IsTerminal) - { - await Task.Delay(TimeSpan.FromSeconds(1)); - run = await client.GetRunAsync(run.ThreadId, run.Id); + + IEnumerable updates = runOperation.GetUpdates(); + + foreach (ThreadRun update in updates) + { // If the run requires action, resolve them. - if (run.Status == RunStatus.RequiresAction) + if (runOperation.Status == RunStatus.RequiresAction) { List toolOutputs = []; - foreach (RequiredAction action in run.RequiredActions) + foreach (RequiredAction action in runOperation.Value.RequiredActions) { switch (action.FunctionName) { @@ -142,17 +142,17 @@ string GetCurrentWeather(string location, string unit = "celsius") } // Submit the tool outputs to the assistant, which returns the run to the queued state. - run = await client.SubmitToolOutputsToRunAsync(run.ThreadId, run.Id, toolOutputs); + await runOperation.SubmitToolOutputsToRunAsync(toolOutputs); } } #endregion #region // With the run complete, list the messages and display their content - if (run.Status == RunStatus.Completed) + if (runOperation.Status == RunStatus.Completed) { AsyncPageCollection messagePages - = client.GetMessagesAsync(run.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); + = client.GetMessagesAsync(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); IAsyncEnumerable messages = messagePages.GetAllValuesAsync(); await foreach (ThreadMessage message in messages) @@ -186,7 +186,7 @@ AsyncPageCollection messagePages } else { - throw new NotImplementedException(run.Status.ToString()); + throw new NotImplementedException(runOperation.Status.ToString()); } #endregion } diff --git a/examples/Assistants/Example02b_FunctionCallingStreaming.cs b/examples/Assistants/Example02b_FunctionCallingStreaming.cs index 9c3e0adf..38b81177 100644 --- a/examples/Assistants/Example02b_FunctionCallingStreaming.cs +++ b/examples/Assistants/Example02b_FunctionCallingStreaming.cs @@ -1,7 +1,6 @@ using NUnit.Framework; using OpenAI.Assistants; using System; -using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; using System.Threading.Tasks; @@ -91,43 +90,34 @@ public async Task Example02b_FunctionCallingStreaming() #endregion #region Step 3 - Initiate a streaming run - AsyncCollectionResult asyncUpdates - = client.CreateRunStreamingAsync(thread, assistant); + StreamingRunOperation runOperation = client.CreateRunStreaming(thread, assistant); + IAsyncEnumerable updates = runOperation.GetUpdatesStreamingAsync(); - ThreadRun currentRun = null; - do + await foreach (StreamingUpdate update in updates) { - currentRun = null; - List outputsToSubmit = []; - await foreach (StreamingUpdate update in asyncUpdates) + if (update is RequiredActionUpdate requiredActionUpdate) { - if (update is RunUpdate runUpdate) - { - currentRun = runUpdate; - } - else if (update is RequiredActionUpdate requiredActionUpdate) + List outputsToSubmit = []; + + foreach (RequiredAction action in requiredActionUpdate.RequiredActions) { - if (requiredActionUpdate.FunctionName == getTemperatureTool.FunctionName) + if (action.FunctionName == getTemperatureTool.FunctionName) { - outputsToSubmit.Add(new ToolOutput(requiredActionUpdate.ToolCallId, "57")); + outputsToSubmit.Add(new ToolOutput(action.ToolCallId, "57")); } - else if (requiredActionUpdate.FunctionName == getRainProbabilityTool.FunctionName) + else if (action.FunctionName == getRainProbabilityTool.FunctionName) { - outputsToSubmit.Add(new ToolOutput(requiredActionUpdate.ToolCallId, "25%")); + outputsToSubmit.Add(new ToolOutput(action.ToolCallId, "25%")); } } - else if (update is MessageContentUpdate contentUpdate) - { - Console.Write(contentUpdate.Text); - } + + await runOperation.SubmitToolOutputsToRunStreamingAsync(outputsToSubmit); } - if (outputsToSubmit.Count > 0) + else if (update is MessageContentUpdate contentUpdate) { - asyncUpdates = client.SubmitToolOutputsToRunStreamingAsync(currentRun, outputsToSubmit); + Console.Write(contentUpdate.Text); } } - while (currentRun?.Status.IsTerminal == false); - #endregion // Optionally, delete the resources for tidiness if no longer needed. diff --git a/examples/Assistants/Example04_AllTheTools.cs b/examples/Assistants/Example04_AllTheTools.cs index 3d55253d..80fd09e7 100644 --- a/examples/Assistants/Example04_AllTheTools.cs +++ b/examples/Assistants/Example04_AllTheTools.cs @@ -5,7 +5,6 @@ using System.ClientModel; using System.Collections.Generic; using System.Text.Json; -using System.Threading; namespace OpenAI.Examples; @@ -89,22 +88,20 @@ static string GetNameOfFamilyMember(string relation) } }); - ThreadRun run = client.CreateRun(thread, assistant); + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread, assistant); #endregion #region Complete the run, calling functions as needed - // Poll the run until it is no longer queued or in progress. - while (!run.Status.IsTerminal) - { - Thread.Sleep(TimeSpan.FromSeconds(1)); - run = client.GetRun(run.ThreadId, run.Id); + IEnumerable updates = runOperation.GetUpdates(); + foreach (ThreadRun update in updates) + { // If the run requires action, resolve them. - if (run.Status == RunStatus.RequiresAction) + if (runOperation.Status == RunStatus.RequiresAction) { List toolOutputs = []; - foreach (RequiredAction action in run.RequiredActions) + foreach (RequiredAction action in runOperation.Value.RequiredActions) { switch (action.FunctionName) { @@ -128,17 +125,17 @@ static string GetNameOfFamilyMember(string relation) } // Submit the tool outputs to the assistant, which returns the run to the queued state. - run = client.SubmitToolOutputsToRun(run.ThreadId, run.Id, toolOutputs); + runOperation.SubmitToolOutputsToRun(toolOutputs); } } #endregion #region // With the run complete, list the messages and display their content - if (run.Status == RunStatus.Completed) + if (runOperation.Status == RunStatus.Completed) { PageCollection messagePages - = client.GetMessages(run.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); + = client.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.OldestFirst }); IEnumerable messages = messagePages.GetAllValues(); foreach (ThreadMessage message in messages) @@ -172,11 +169,7 @@ PageCollection messagePages #endregion #region List run steps for details about tool calls - PageCollection runSteps = client.GetRunSteps( - run, new RunStepCollectionOptions() - { - Order = ListOrder.OldestFirst - }); + PageCollection runSteps = runOperation.GetRunSteps(new RunStepCollectionOptions() { Order = ListOrder.OldestFirst }); foreach (RunStep step in runSteps.GetAllValues()) { Console.WriteLine($"Run step: {step.Status}"); @@ -193,7 +186,7 @@ PageCollection messagePages } else { - throw new NotImplementedException(run.Status.ToString()); + throw new NotImplementedException(runOperation.Status.ToString()); } #endregion diff --git a/examples/Assistants/Example05_AssistantsWithVision.cs b/examples/Assistants/Example05_AssistantsWithVision.cs index 4d10c84c..178df6a7 100644 --- a/examples/Assistants/Example05_AssistantsWithVision.cs +++ b/examples/Assistants/Example05_AssistantsWithVision.cs @@ -44,7 +44,7 @@ public void Example05_AssistantsWithVision() } }); - CollectionResult streamingUpdates = assistantClient.CreateRunStreaming( + StreamingRunOperation runOperation = assistantClient.CreateRunStreaming( thread, assistant, new RunCreationOptions() @@ -52,7 +52,7 @@ public void Example05_AssistantsWithVision() AdditionalInstructions = "When possible, try to sneak in puns if you're asked to compare things.", }); - foreach (StreamingUpdate streamingUpdate in streamingUpdates) + foreach (StreamingUpdate streamingUpdate in runOperation.GetUpdatesStreaming()) { if (streamingUpdate.UpdateKind == StreamingUpdateReason.RunCreated) { diff --git a/examples/Assistants/Example05_AssistantsWithVisionAsync.cs b/examples/Assistants/Example05_AssistantsWithVisionAsync.cs index 3f79137e..8f9129d0 100644 --- a/examples/Assistants/Example05_AssistantsWithVisionAsync.cs +++ b/examples/Assistants/Example05_AssistantsWithVisionAsync.cs @@ -45,7 +45,7 @@ public async Task Example05_AssistantsWithVisionAsync() } }); - AsyncCollectionResult streamingUpdates = assistantClient.CreateRunStreamingAsync( + StreamingRunOperation runOperation = assistantClient.CreateRunStreaming( thread, assistant, new RunCreationOptions() @@ -53,7 +53,7 @@ public async Task Example05_AssistantsWithVisionAsync() AdditionalInstructions = "When possible, try to sneak in puns if you're asked to compare things.", }); - await foreach (StreamingUpdate streamingUpdate in streamingUpdates) + await foreach (StreamingUpdate streamingUpdate in runOperation.GetUpdatesStreamingAsync()) { if (streamingUpdate.UpdateKind == StreamingUpdateReason.RunCreated) { diff --git a/src/Custom/Assistants/AssistantClient.Convenience.cs b/src/Custom/Assistants/AssistantClient.Convenience.cs index 5eb15133..fa4e020d 100644 --- a/src/Custom/Assistants/AssistantClient.Convenience.cs +++ b/src/Custom/Assistants/AssistantClient.Convenience.cs @@ -218,8 +218,12 @@ public virtual ClientResult DeleteMessage(ThreadMessage message) /// The assistant that should be used when evaluating the thread. /// Additional options for the run. /// A new instance. - public virtual Task> CreateRunAsync(AssistantThread thread, Assistant assistant, RunCreationOptions options = null) - => CreateRunAsync(thread?.Id, assistant?.Id, options); + public virtual async Task CreateRunAsync( + ReturnWhen returnWhen, + AssistantThread thread, + Assistant assistant, + RunCreationOptions options = null) + => await CreateRunAsync(returnWhen, thread?.Id, assistant?.Id, options).ConfigureAwait(false); /// /// Begins a new that evaluates a using a specified @@ -229,21 +233,12 @@ public virtual Task> CreateRunAsync(AssistantThread thre /// The assistant that should be used when evaluating the thread. /// Additional options for the run. /// A new instance. - public virtual ClientResult CreateRun(AssistantThread thread, Assistant assistant, RunCreationOptions options = null) - => CreateRun(thread?.Id, assistant?.Id, options); - - /// - /// Begins a new streaming that evaluates a using a specified - /// . - /// - /// The thread that the run should evaluate. - /// The assistant that should be used when evaluating the thread. - /// Additional options for the run. - public virtual AsyncCollectionResult CreateRunStreamingAsync( - AssistantThread thread, - Assistant assistant, + public virtual RunOperation CreateRun( + ReturnWhen returnWhen, + AssistantThread thread, + Assistant assistant, RunCreationOptions options = null) - => CreateRunStreamingAsync(thread?.Id, assistant?.Id, options); + => CreateRun(returnWhen, thread?.Id, assistant?.Id, options); /// /// Begins a new streaming that evaluates a using a specified @@ -252,7 +247,7 @@ public virtual AsyncCollectionResult CreateRunStreamingAsync( /// The thread that the run should evaluate. /// The assistant that should be used when evaluating the thread. /// Additional options for the run. - public virtual CollectionResult CreateRunStreaming( + public virtual StreamingRunOperation CreateRunStreaming( AssistantThread thread, Assistant assistant, RunCreationOptions options = null) @@ -265,11 +260,12 @@ public virtual CollectionResult CreateRunStreaming( /// Options for the new thread that will be created. /// Additional options to apply to the run that will begin. /// A new . - public virtual Task> CreateThreadAndRunAsync( + public virtual async Task CreateThreadAndRunAsync( + ReturnWhen returnWhen, Assistant assistant, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) - => CreateThreadAndRunAsync(assistant?.Id, threadOptions, runOptions); + => await CreateThreadAndRunAsync(returnWhen, assistant?.Id, threadOptions, runOptions).ConfigureAwait(false); /// /// Creates a new thread and immediately begins a run against it using the specified . @@ -278,23 +274,12 @@ public virtual Task> CreateThreadAndRunAsync( /// Options for the new thread that will be created. /// Additional options to apply to the run that will begin. /// A new . - public virtual ClientResult CreateThreadAndRun( - Assistant assistant, - ThreadCreationOptions threadOptions = null, - RunCreationOptions runOptions = null) - => CreateThreadAndRun(assistant?.Id, threadOptions, runOptions); - - /// - /// Creates a new thread and immediately begins a streaming run against it using the specified . - /// - /// The assistant that the new run should use. - /// Options for the new thread that will be created. - /// Additional options to apply to the run that will begin. - public virtual AsyncCollectionResult CreateThreadAndRunStreamingAsync( + public virtual RunOperation CreateThreadAndRun( + ReturnWhen returnWhen, Assistant assistant, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) - => CreateThreadAndRunStreamingAsync(assistant?.Id, threadOptions, runOptions); + => CreateThreadAndRun(returnWhen, assistant?.Id, threadOptions, runOptions); /// /// Creates a new thread and immediately begins a streaming run against it using the specified . @@ -302,7 +287,7 @@ public virtual AsyncCollectionResult CreateThreadAndRunStreamin /// The assistant that the new run should use. /// Options for the new thread that will be created. /// Additional options to apply to the run that will begin. - public virtual CollectionResult CreateThreadAndRunStreaming( + public virtual StreamingRunOperation CreateThreadAndRunStreaming( Assistant assistant, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null) @@ -343,122 +328,4 @@ public virtual PageCollection GetRuns( return GetRuns(thread.Id, options); } - - /// - /// Gets a refreshed instance of an existing . - /// - /// The run to get a refreshed instance of. - /// A new instance with updated information. - public virtual Task> GetRunAsync(ThreadRun run) - => GetRunAsync(run?.ThreadId, run?.Id); - - /// - /// Gets a refreshed instance of an existing . - /// - /// The run to get a refreshed instance of. - /// A new instance with updated information. - public virtual ClientResult GetRun(ThreadRun run) - => GetRun(run?.ThreadId, run?.Id); - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run. - /// - /// The run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - /// The , updated after the submission was processed. - public virtual Task> SubmitToolOutputsToRunAsync( - ThreadRun run, - IEnumerable toolOutputs) - => SubmitToolOutputsToRunAsync(run?.ThreadId, run?.Id, toolOutputs); - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run. - /// - /// The run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - /// The , updated after the submission was processed. - public virtual ClientResult SubmitToolOutputsToRun( - ThreadRun run, - IEnumerable toolOutputs) - => SubmitToolOutputsToRun(run?.ThreadId, run?.Id, toolOutputs); - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run with streaming enabled. - /// - /// The run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - public virtual AsyncCollectionResult SubmitToolOutputsToRunStreamingAsync( - ThreadRun run, - IEnumerable toolOutputs) - => SubmitToolOutputsToRunStreamingAsync(run?.ThreadId, run?.Id, toolOutputs); - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run with streaming enabled. - /// - /// The run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - public virtual CollectionResult SubmitToolOutputsToRunStreaming( - ThreadRun run, - IEnumerable toolOutputs) - => SubmitToolOutputsToRunStreaming(run?.ThreadId, run?.Id, toolOutputs); - - /// - /// Cancels an in-progress . - /// - /// The run to cancel. - /// An updated instance, reflecting the new status of the run. - public virtual Task> CancelRunAsync(ThreadRun run) - => CancelRunAsync(run?.ThreadId, run?.Id); - - /// - /// Cancels an in-progress . - /// - /// The run to cancel. - /// An updated instance, reflecting the new status of the run. - public virtual ClientResult CancelRun(ThreadRun run) - => CancelRun(run?.ThreadId, run?.Id); - - /// - /// Gets a page collection holding instances associated with a . - /// - /// The run to list run steps from. - /// Options describing the collection to return. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual AsyncPageCollection GetRunStepsAsync( - ThreadRun run, - RunStepCollectionOptions options = default) - { - Argument.AssertNotNull(run, nameof(run)); - - return GetRunStepsAsync(run.ThreadId, run.Id, options); - } - - /// - /// Gets a page collection holding instances associated with a . - /// - /// The run to list run steps from. - /// Options describing the collection to return. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual PageCollection GetRunSteps( - ThreadRun run, - RunStepCollectionOptions options = default) - { - Argument.AssertNotNull(run, nameof(run)); - - return GetRunSteps(run.ThreadId, run.Id, options); - } } diff --git a/src/Custom/Assistants/AssistantClient.Protocol.cs b/src/Custom/Assistants/AssistantClient.Protocol.cs index d96ee961..df79e33f 100644 --- a/src/Custom/Assistants/AssistantClient.Protocol.cs +++ b/src/Custom/Assistants/AssistantClient.Protocol.cs @@ -329,25 +329,65 @@ public virtual Task DeleteMessageAsync(string threadId, string mes public virtual ClientResult DeleteMessage(string threadId, string messageId, RequestOptions options) => _messageSubClient.DeleteMessage(threadId, messageId, options); - /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task CreateThreadAndRunAsync(BinaryContent content, RequestOptions options = null) - => _runSubClient.CreateThreadAndRunAsync(content, options); - - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult CreateThreadAndRun(BinaryContent content, RequestOptions options = null) - => _runSubClient.CreateThreadAndRun(content, options = null); - - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task CreateRunAsync(string threadId, BinaryContent content, RequestOptions options = null) - => _runSubClient.CreateRunAsync(threadId, content, options); + public virtual async Task CreateThreadAndRunAsync( + ReturnWhen returnWhen, + BinaryContent content, + RequestOptions options = null) + { + ClientResult result = await _runSubClient.CreateThreadAndRunAsync(content, options).ConfigureAwait(false); + PipelineResponse response = result.GetRawResponse(); + RunOperation operation = new RunOperation(_pipeline, _endpoint, response); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + bool isStreaming = false; + if (response.Headers.TryGetValue("Content-Type", out string contentType)) + { + isStreaming = contentType == "text/event-stream; charset=utf-8"; + } + + if (isStreaming) + { + throw new NotSupportedException("Streaming runs cannot use 'ReturnWhen.Completed'"); + } + + await operation.WaitUntilStoppedAsync(options?.CancellationToken ?? default).ConfigureAwait(false); + return operation; + } - /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult CreateRun(string threadId, BinaryContent content, RequestOptions options = null) - => _runSubClient.CreateRun(threadId, content, options); + public virtual RunOperation CreateThreadAndRun( + ReturnWhen returnWhen, + BinaryContent content, + RequestOptions options = null) + { + ClientResult result = _runSubClient.CreateThreadAndRun(content, options); + PipelineResponse response = result.GetRawResponse(); + RunOperation operation = new RunOperation(_pipeline, _endpoint, response); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + bool isStreaming = false; + if (response.Headers.TryGetValue("Content-Type", out string contentType)) + { + isStreaming = contentType == "text/event-stream; charset=utf-8"; + } + + if (isStreaming) + { + throw new NotSupportedException("Streaming runs cannot use 'ReturnWhen.Completed'"); + } + + operation.WaitUntilStopped(options?.CancellationToken ?? default); + return operation; + } /// /// [Protocol Method] Returns a paginated collection of runs belonging to a thread. @@ -421,131 +461,97 @@ public virtual IEnumerable GetRuns(string threadId, int? limit, st return PageCollectionHelpers.Create(enumerator); } - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task GetRunAsync(string threadId, string runId, RequestOptions options) - => _runSubClient.GetRunAsync(threadId, runId, options); - - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult GetRun(string threadId, string runId, RequestOptions options) - => _runSubClient.GetRun(threadId, runId, options); - - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task ModifyRunAsync(string threadId, string runId, BinaryContent content, RequestOptions options = null) - => _runSubClient.ModifyRunAsync(threadId, runId, content, options); - - /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult ModifyRun(string threadId, string runId, BinaryContent content, RequestOptions options = null) - => _runSubClient.ModifyRun(threadId, runId, content, options); - - /// - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task CancelRunAsync(string threadId, string runId, RequestOptions options) - => _runSubClient.CancelRunAsync(threadId, runId, options); + public virtual async Task CreateRunAsync( + ReturnWhen returnWhen, + string threadId, + BinaryContent content, + RequestOptions options = null) + { + ClientResult result = await _runSubClient.CreateRunAsync(threadId, content, options).ConfigureAwait(false); + PipelineResponse response = result.GetRawResponse(); + RunOperation operation = new RunOperation(_pipeline, _endpoint, response); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + bool isStreaming = false; + if (response.Headers.TryGetValue("Content-Type", out string contentType)) + { + isStreaming = contentType == "text/event-stream; charset=utf-8"; + } + + if (isStreaming) + { + throw new NotSupportedException("Streaming runs cannot use 'ReturnWhen.Completed'"); + } + + await operation.WaitUntilStoppedAsync(options?.CancellationToken ?? default).ConfigureAwait(false); + return operation; + } - /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult CancelRun(string threadId, string runId, RequestOptions options) - => _runSubClient.CancelRun(threadId, runId, options); + public virtual RunOperation CreateRun( + ReturnWhen returnWhen, + string threadId, + BinaryContent content, + RequestOptions options = null) + { + ClientResult result = _runSubClient.CreateRun(threadId, content, options); + PipelineResponse response = result.GetRawResponse(); + RunOperation operation = new RunOperation(_pipeline, _endpoint, response); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + bool isStreaming = false; + if (response.Headers.TryGetValue("Content-Type", out string contentType)) + { + isStreaming = contentType == "text/event-stream; charset=utf-8"; + } + + if (isStreaming) + { + throw new NotSupportedException("Streaming runs cannot use 'ReturnWhen.Completed'"); + } + + operation.WaitUntilStopped(options?.CancellationToken ?? default); + return operation; + } - /// + /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task SubmitToolOutputsToRunAsync(string threadId, string runId, BinaryContent content, RequestOptions options = null) - => _runSubClient.SubmitToolOutputsToRunAsync(threadId, runId, content, options); + internal virtual Task CreateThreadAndRunAsync(BinaryContent content, RequestOptions options = null) + => _runSubClient.CreateThreadAndRunAsync(content, options); - /// + /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult SubmitToolOutputsToRun(string threadId, string runId, BinaryContent content, RequestOptions options = null) - => _runSubClient.SubmitToolOutputsToRun(threadId, runId, content, options); + internal virtual ClientResult CreateThreadAndRun(BinaryContent content, RequestOptions options = null) + => _runSubClient.CreateThreadAndRun(content, options); - /// - /// [Protocol Method] Returns a paginated collection of run steps belonging to a run. - /// - /// The ID of the thread the run and run steps belong to. - /// The ID of the run the run steps belong to. - /// - /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the - /// default is 20. - /// - /// - /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` - /// for descending order. Allowed values: "asc" | "desc" - /// - /// - /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include after=obj_foo in order to fetch the next page of the list. - /// - /// - /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. - /// - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// A collection of service responses, each holding a page of values. + /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual IAsyncEnumerable GetRunStepsAsync(string threadId, string runId, int? limit, string order, string after, string before, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - RunStepsPageEnumerator enumerator = new RunStepsPageEnumerator(_pipeline, _endpoint, threadId, runId, limit, order, after, before, options); - return PageCollectionHelpers.CreateAsync(enumerator); - } + internal virtual Task CreateRunAsync(string threadId, BinaryContent content, RequestOptions options = null) + => _runSubClient.CreateRunAsync(threadId, content, options); - /// - /// [Protocol Method] Returns a paginated collection of run steps belonging to a run. - /// - /// The ID of the thread the run and run steps belong to. - /// The ID of the run the run steps belong to. - /// - /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the - /// default is 20. - /// - /// - /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` - /// for descending order. Allowed values: "asc" | "desc" - /// - /// - /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include after=obj_foo in order to fetch the next page of the list. - /// - /// - /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. - /// - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// A collection of service responses, each holding a page of values. + /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual IEnumerable GetRunSteps(string threadId, string runId, int? limit, string order, string after, string before, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - RunStepsPageEnumerator enumerator = new RunStepsPageEnumerator(_pipeline, _endpoint, threadId, runId, limit, order, after, before, options); - return PageCollectionHelpers.Create(enumerator); - } + internal virtual ClientResult CreateRun(string threadId, BinaryContent content, RequestOptions options = null) + => _runSubClient.CreateRun(threadId, content, options); - /// + /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual Task GetRunStepAsync(string threadId, string runId, string stepId, RequestOptions options) - => _runSubClient.GetRunStepAsync(threadId, runId, stepId, options); + internal virtual Task GetRunAsync(string threadId, string runId, RequestOptions options = null) + => _runSubClient.GetRunAsync(threadId, runId, options); - /// + /// [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult GetRunStep(string threadId, string runId, string stepId, RequestOptions options) - => _runSubClient.GetRunStep(threadId, runId, stepId, options); + internal virtual ClientResult GetRun(string threadId, string runId, RequestOptions options = null) + => _runSubClient.GetRun(threadId, runId, options); /// [EditorBrowsable(EditorBrowsableState.Never)] diff --git a/src/Custom/Assistants/AssistantClient.cs b/src/Custom/Assistants/AssistantClient.cs index 611ef7ed..41850cdd 100644 --- a/src/Custom/Assistants/AssistantClient.cs +++ b/src/Custom/Assistants/AssistantClient.cs @@ -3,8 +3,8 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; @@ -18,10 +18,6 @@ namespace OpenAI.Assistants; [CodeGenSuppress("AssistantClient", typeof(ClientPipeline), typeof(ApiKeyCredential), typeof(Uri))] [CodeGenSuppress("CreateAssistantAsync", typeof(AssistantCreationOptions))] [CodeGenSuppress("CreateAssistant", typeof(AssistantCreationOptions))] -[CodeGenSuppress("GetAssistantAsync", typeof(string))] -[CodeGenSuppress("GetAssistant", typeof(string))] -[CodeGenSuppress("ModifyAssistantAsync", typeof(string), typeof(AssistantModificationOptions))] -[CodeGenSuppress("ModifyAssistant", typeof(string), typeof(AssistantModificationOptions))] [CodeGenSuppress("DeleteAssistantAsync", typeof(string))] [CodeGenSuppress("DeleteAssistant", typeof(string))] [CodeGenSuppress("GetAssistantsAsync", typeof(int?), typeof(ListOrder?), typeof(string), typeof(string))] @@ -691,77 +687,103 @@ public virtual ClientResult DeleteMessage(string threadId, string messageI /// Additional options for the run. /// A token that can be used to cancel this method call. /// A new instance. - public virtual async Task> CreateRunAsync(string threadId, string assistantId, RunCreationOptions options = null, CancellationToken cancellationToken = default) + public virtual async Task CreateRunAsync( + ReturnWhen returnWhen, + string threadId, + string assistantId, + RunCreationOptions options = null, + CancellationToken cancellationToken = default) { Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId)); + options ??= new(); options.AssistantId = assistantId; options.Stream = null; - ClientResult protocolResult = await CreateRunAsync(threadId, options.ToBinaryContent(), cancellationToken.ToRequestOptions()) - .ConfigureAwait(false); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + + ClientResult protocolResult = await CreateRunAsync(threadId, options.ToBinaryContent(), cancellationToken.ToRequestOptions()).ConfigureAwait(false); + ClientResult result = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + + RunOperation operation = new RunOperation(_pipeline, _endpoint, + value: result, + status: result.Value.Status, + result.GetRawResponse()); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + await operation.WaitUntilStoppedAsync().ConfigureAwait(false); + return operation; } - /// - /// Begins a new that evaluates a using a specified - /// . - /// - /// The ID of the thread that the run should evaluate. - /// The ID of the assistant that should be used when evaluating the thread. - /// Additional options for the run. - /// A token that can be used to cancel this method call. - /// A new instance. - public virtual ClientResult CreateRun(string threadId, string assistantId, RunCreationOptions options = null, CancellationToken cancellationToken = default) + public virtual RunOperation CreateRun( + ReturnWhen returnWhen, + string threadId, + string assistantId, + RunCreationOptions options = null, + CancellationToken cancellationToken = default) { Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId)); + options ??= new(); options.AssistantId = assistantId; options.Stream = null; + ClientResult protocolResult = CreateRun(threadId, options.ToBinaryContent(), cancellationToken.ToRequestOptions()); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + ClientResult result = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + + RunOperation operation = new RunOperation(_pipeline, _endpoint, + value: result, + status: result.Value.Status, + result.GetRawResponse()); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + operation.WaitUntilStopped(); + return operation; } - /// - /// Begins a new streaming that evaluates a using a specified - /// . - /// - /// The ID of the thread that the run should evaluate. - /// The ID of the assistant that should be used when evaluating the thread. - /// Additional options for the run. - /// A token that can be used to cancel this method call. - public virtual AsyncCollectionResult CreateRunStreamingAsync( - string threadId, - string assistantId, - RunCreationOptions options = null, + public virtual async Task ContinueRunAsync( + ContinuationToken rehydrationToken, CancellationToken cancellationToken = default) { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId)); + Argument.AssertNotNull(rehydrationToken, nameof(rehydrationToken)); - options ??= new(); - options.AssistantId = assistantId; - options.Stream = true; + RunOperationToken token = RunOperationToken.FromToken(rehydrationToken); + ClientResult protocolResult = await GetRunAsync(token.ThreadId, token.RunId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); + ClientResult result = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - async Task getResultAsync() => - await CreateRunAsync(threadId, options.ToBinaryContent(), cancellationToken.ToRequestOptions(streaming: true)) - .ConfigureAwait(false); + return new RunOperation(_pipeline, _endpoint, + value: result, + status: result.Value.Status, + result.GetRawResponse()); + } + + public virtual RunOperation ContinueRun( + ContinuationToken rehydrationToken, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(rehydrationToken, nameof(rehydrationToken)); + + RunOperationToken token = RunOperationToken.FromToken(rehydrationToken); + ClientResult protocolResult = GetRun(token.ThreadId, token.RunId, cancellationToken.ToRequestOptions()); + ClientResult result = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - return new AsyncStreamingUpdateCollection(getResultAsync); + return new RunOperation(_pipeline, _endpoint, + value: result, + status: result.Value.Status, + result.GetRawResponse()); } - /// - /// Begins a new streaming that evaluates a using a specified - /// . - /// - /// The ID of the thread that the run should evaluate. - /// The ID of the assistant that should be used when evaluating the thread. - /// Additional options for the run. - /// A token that can be used to cancel this method call. - public virtual CollectionResult CreateRunStreaming( + public virtual StreamingRunOperation CreateRunStreaming( string threadId, string assistantId, RunCreationOptions options = null, @@ -774,61 +796,51 @@ public virtual CollectionResult CreateRunStreaming( options.AssistantId = assistantId; options.Stream = true; - ClientResult getResult() => CreateRun(threadId, options.ToBinaryContent(), cancellationToken.ToRequestOptions(streaming: true)); + BinaryContent content = options.ToBinaryContent(); + RequestOptions requestOptions = cancellationToken.ToRequestOptions(); - return new StreamingUpdateCollection(getResult); - } + async Task getResultAsync() => + await _runSubClient.CreateRunAsync(threadId, content, requestOptions) + .ConfigureAwait(false); - /// - /// Creates a new thread and immediately begins a run against it using the specified . - /// - /// The ID of the assistant that the new run should use. - /// Options for the new thread that will be created. - /// Additional options to apply to the run that will begin. - /// A token that can be used to cancel this method call. - /// A new . - public virtual async Task> CreateThreadAndRunAsync( - string assistantId, - ThreadCreationOptions threadOptions = null, - RunCreationOptions runOptions = null, - CancellationToken cancellationToken = default) - { - runOptions ??= new(); - runOptions.Stream = null; - BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions); - ClientResult protocolResult = await CreateThreadAndRunAsync(protocolContent, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + ClientResult getResult() => + _runSubClient.CreateRun(threadId, content, requestOptions); + + return new StreamingRunOperation(_pipeline, _endpoint, getResultAsync, getResult); } - /// - /// Creates a new thread and immediately begins a run against it using the specified . - /// - /// The ID of the assistant that the new run should use. - /// Options for the new thread that will be created. - /// Additional options to apply to the run that will begin. - /// A token that can be used to cancel this method call. - /// A new . - public virtual ClientResult CreateThreadAndRun( + public virtual async Task CreateThreadAndRunAsync( + ReturnWhen returnWhen, string assistantId, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null, CancellationToken cancellationToken = default) { + Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId)); + runOptions ??= new(); + runOptions.AssistantId = assistantId; runOptions.Stream = null; - BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions); - ClientResult protocolResult = CreateThreadAndRun(protocolContent, cancellationToken.ToRequestOptions()); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + + ClientResult protocolResult = await CreateThreadAndRunAsync(runOptions.ToBinaryContent(), cancellationToken.ToRequestOptions()).ConfigureAwait(false); + ClientResult result = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + + RunOperation operation = new RunOperation(_pipeline, _endpoint, + value: result, + status: result.Value.Status, + result.GetRawResponse()); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + await operation.WaitUntilStoppedAsync().ConfigureAwait(false); + return operation; } - /// - /// Creates a new thread and immediately begins a streaming run against it using the specified . - /// - /// The ID of the assistant that the new run should use. - /// Options for the new thread that will be created. - /// Additional options to apply to the run that will begin. - /// A token that can be used to cancel this method call. - public virtual AsyncCollectionResult CreateThreadAndRunStreamingAsync( + public virtual RunOperation CreateThreadAndRun( + ReturnWhen returnWhen, string assistantId, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null, @@ -837,24 +849,27 @@ public virtual AsyncCollectionResult CreateThreadAndRunStreamin Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId)); runOptions ??= new(); - runOptions.Stream = true; - BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions); + runOptions.AssistantId = assistantId; + runOptions.Stream = null; - async Task getResultAsync() => - await CreateThreadAndRunAsync(protocolContent, cancellationToken.ToRequestOptions(streaming: true)) - .ConfigureAwait(false); + ClientResult protocolResult = CreateThreadAndRun(runOptions.ToBinaryContent(), cancellationToken.ToRequestOptions()); + ClientResult result = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + + RunOperation operation = new RunOperation(_pipeline, _endpoint, + value: result, + status: result.Value.Status, + result.GetRawResponse()); + + if (returnWhen == ReturnWhen.Started) + { + return operation; + } - return new AsyncStreamingUpdateCollection(getResultAsync); + operation.WaitUntilStopped(); + return operation; } - /// - /// Creates a new thread and immediately begins a streaming run against it using the specified . - /// - /// The ID of the assistant that the new run should use. - /// Options for the new thread that will be created. - /// Additional options to apply to the run that will begin. - /// A token that can be used to cancel this method call. - public virtual CollectionResult CreateThreadAndRunStreaming( + public virtual StreamingRunOperation CreateThreadAndRunStreaming( string assistantId, ThreadCreationOptions threadOptions = null, RunCreationOptions runOptions = null, @@ -864,11 +879,18 @@ public virtual CollectionResult CreateThreadAndRunStreaming( runOptions ??= new(); runOptions.Stream = true; + BinaryContent protocolContent = CreateThreadAndRunProtocolContent(assistantId, threadOptions, runOptions); + RequestOptions requestOptions = cancellationToken.ToRequestOptions(); + + async Task getResultAsync() => + await _runSubClient.CreateThreadAndRunAsync(protocolContent, requestOptions) + .ConfigureAwait(false); - ClientResult getResult() => CreateThreadAndRun(protocolContent, cancellationToken.ToRequestOptions(streaming: true)); + ClientResult getResult() => + _runSubClient.CreateThreadAndRun(protocolContent, requestOptions); - return new StreamingUpdateCollection(getResult); + return new StreamingRunOperation(_pipeline, _endpoint, getResultAsync, getResult); } /// @@ -981,321 +1003,6 @@ public virtual PageCollection GetRuns( return PageCollectionHelpers.Create(enumerator); } - /// - /// Gets an existing from a known . - /// - /// The ID of the thread to retrieve the run from. - /// The ID of the run to retrieve. - /// A token that can be used to cancel this method call. - /// The existing instance. - public virtual async Task> GetRunAsync(string threadId, string runId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - ClientResult protocolResult = await GetRunAsync(threadId, runId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - } - - /// - /// Gets an existing from a known . - /// - /// The ID of the thread to retrieve the run from. - /// The ID of the run to retrieve. - /// A token that can be used to cancel this method call. - /// The existing instance. - public virtual ClientResult GetRun(string threadId, string runId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - ClientResult protocolResult = GetRun(threadId, runId, cancellationToken.ToRequestOptions()); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - } - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run. - /// - /// The thread ID of the thread being run. - /// The ID of the run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - /// A token that can be used to cancel this method call. - /// The , updated after the submission was processed. - public virtual async Task> SubmitToolOutputsToRunAsync( - string threadId, - string runId, - IEnumerable toolOutputs, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs).ToBinaryContent(); - ClientResult protocolResult = await SubmitToolOutputsToRunAsync(threadId, runId, content, cancellationToken.ToRequestOptions()) - .ConfigureAwait(false); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - } - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run. - /// - /// The thread ID of the thread being run. - /// The ID of the run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - /// A token that can be used to cancel this method call. - /// The , updated after the submission was processed. - public virtual ClientResult SubmitToolOutputsToRun( - string threadId, - string runId, - IEnumerable toolOutputs, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs).ToBinaryContent(); - ClientResult protocolResult = SubmitToolOutputsToRun(threadId, runId, content, cancellationToken.ToRequestOptions()); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - } - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run with streaming enabled. - /// - /// The thread ID of the thread being run. - /// The ID of the run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - /// A token that can be used to cancel this method call. - public virtual AsyncCollectionResult SubmitToolOutputsToRunStreamingAsync( - string threadId, - string runId, - IEnumerable toolOutputs, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs.ToList(), stream: true, null) - .ToBinaryContent(); - - async Task getResultAsync() => - await SubmitToolOutputsToRunAsync(threadId, runId, content, cancellationToken.ToRequestOptions(streaming: true)) - .ConfigureAwait(false); - - return new AsyncStreamingUpdateCollection(getResultAsync); - } - - /// - /// Submits a collection of required tool call outputs to a run and resumes the run with streaming enabled. - /// - /// The thread ID of the thread being run. - /// The ID of the run that reached a requires_action status. - /// - /// The tool outputs, corresponding to instances from the run. - /// - /// A token that can be used to cancel this method call. - public virtual CollectionResult SubmitToolOutputsToRunStreaming( - string threadId, - string runId, - IEnumerable toolOutputs, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs.ToList(), stream: true, null) - .ToBinaryContent(); - - ClientResult getResult() => SubmitToolOutputsToRun(threadId, runId, content, cancellationToken.ToRequestOptions(streaming: true)); - - return new StreamingUpdateCollection(getResult); - } - - /// - /// Cancels an in-progress . - /// - /// The ID of the thread associated with the run. - /// The ID of the run to cancel. - /// A token that can be used to cancel this method call. - /// An updated instance, reflecting the new status of the run. - public virtual async Task> CancelRunAsync(string threadId, string runId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - ClientResult protocolResult = await CancelRunAsync(threadId, runId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - } - - /// - /// Cancels an in-progress . - /// - /// The ID of the thread associated with the run. - /// The ID of the run to cancel. - /// A token that can be used to cancel this method call. - /// An updated instance, reflecting the new status of the run. - public virtual ClientResult CancelRun(string threadId, string runId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - ClientResult protocolResult = CancelRun(threadId, runId, cancellationToken.ToRequestOptions()); - return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); - } - - /// - /// Gets a page collection holding instances associated with a . - /// - /// The ID of the thread associated with the run. - /// The ID of the run to list run steps from. - /// - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual AsyncPageCollection GetRunStepsAsync( - string threadId, - string runId, - RunStepCollectionOptions options = default, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, - threadId, - runId, - options?.PageSize, - options?.Order?.ToString(), - options?.AfterId, - options?.BeforeId, - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.CreateAsync(enumerator); - } - - /// - /// Rehydrates a page collection holding instances from a page token. - /// - /// Page token corresponding to the first page of the collection to rehydrate. - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual AsyncPageCollection GetRunStepsAsync( - ContinuationToken firstPageToken, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); - - RunStepsPageToken pageToken = RunStepsPageToken.FromToken(firstPageToken); - RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, - pageToken.ThreadId, - pageToken.RunId, - pageToken.Limit, - pageToken.Order, - pageToken.After, - pageToken.Before, - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.CreateAsync(enumerator); - } - - /// - /// Gets a page collection holding instances associated with a . - /// - /// The ID of the thread associated with the run. - /// The ID of the run to list run steps from. - /// - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual PageCollection GetRunSteps( - string threadId, - string runId, - RunStepCollectionOptions options = default, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, - threadId, - runId, - options?.PageSize, - options?.Order?.ToString(), - options?.AfterId, - options?.BeforeId, - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.Create(enumerator); - } - - /// - /// Rehydrates a page collection holding instances from a page token. - /// - /// Page token corresponding to the first page of the collection to rehydrate. - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual PageCollection GetRunSteps( - ContinuationToken firstPageToken, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); - - RunStepsPageToken pageToken = RunStepsPageToken.FromToken(firstPageToken); - RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, - pageToken.ThreadId, - pageToken.RunId, - pageToken.Limit, - pageToken.Order, - pageToken.After, - pageToken.Before, - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.Create(enumerator); - } - - /// - /// Gets a single run step from a run. - /// - /// The ID of the thread associated with the run. - /// The ID of the run. - /// The ID of the run step. - /// A token that can be used to cancel this method call. - /// A instance corresponding to the specified step. - public virtual async Task> GetRunStepAsync(string threadId, string runId, string stepId, CancellationToken cancellationToken = default) - { - ClientResult protocolResult = await GetRunStepAsync(threadId, runId, stepId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - return CreateResultFromProtocol(protocolResult, RunStep.FromResponse); - } - - /// - /// Gets a single run step from a run. - /// - /// The ID of the thread associated with the run. - /// The ID of the run. - /// The ID of the run step. - /// A token that can be used to cancel this method call. - /// A instance corresponding to the specified step. - public virtual ClientResult GetRunStep(string threadId, string runId, string stepId, CancellationToken cancellationToken = default) - { - ClientResult protocolResult = GetRunStep(threadId, runId, stepId, cancellationToken.ToRequestOptions()); - return CreateResultFromProtocol(protocolResult, RunStep.FromResponse); - } - private static BinaryContent CreateThreadAndRunProtocolContent( string assistantId, ThreadCreationOptions threadOptions, diff --git a/src/Custom/Assistants/Internal/InternalAssistantRunClient.Protocol.cs b/src/Custom/Assistants/Internal/InternalAssistantRunClient.Protocol.cs index 4ffef5d1..865ba7b3 100644 --- a/src/Custom/Assistants/Internal/InternalAssistantRunClient.Protocol.cs +++ b/src/Custom/Assistants/Internal/InternalAssistantRunClient.Protocol.cs @@ -158,194 +158,4 @@ public virtual ClientResult GetRun(string threadId, string runId, RequestOptions using PipelineMessage message = CreateGetRunRequest(threadId, runId, options); return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); } - - /// - /// [Protocol Method] Modifies a run. - /// - /// The ID of the [thread](/docs/api-reference/threads) that was run. - /// The ID of the run to modify. - /// The content to send as the body of the request. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// , or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task ModifyRunAsync(string threadId, string runId, BinaryContent content, RequestOptions options = null) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - Argument.AssertNotNull(content, nameof(content)); - - using PipelineMessage message = CreateModifyRunRequest(threadId, runId, content, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] Modifies a run. - /// - /// The ID of the [thread](/docs/api-reference/threads) that was run. - /// The ID of the run to modify. - /// The content to send as the body of the request. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// , or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult ModifyRun(string threadId, string runId, BinaryContent content, RequestOptions options = null) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - Argument.AssertNotNull(content, nameof(content)); - - using PipelineMessage message = CreateModifyRunRequest(threadId, runId, content, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - /// - /// [Protocol Method] Cancels a run that is `in_progress`. - /// - /// The ID of the thread to which this run belongs. - /// The ID of the run to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task CancelRunAsync(string threadId, string runId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - using PipelineMessage message = CreateCancelRunRequest(threadId, runId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] Cancels a run that is `in_progress`. - /// - /// The ID of the thread to which this run belongs. - /// The ID of the run to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult CancelRun(string threadId, string runId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - - using PipelineMessage message = CreateCancelRunRequest(threadId, runId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - /// - /// [Protocol Method] When a run has the `status: "requires_action"` and `required_action.type` is - /// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the tool calls once - /// they're all completed. All outputs must be submitted in a single request. - /// - /// The ID of the [thread](/docs/api-reference/threads) to which this run belongs. - /// The ID of the run that requires the tool output submission. - /// The content to send as the body of the request. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// , or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task SubmitToolOutputsToRunAsync(string threadId, string runId, BinaryContent content, RequestOptions options = null) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - Argument.AssertNotNull(content, nameof(content)); - - PipelineMessage message = null; - try - { - message = CreateSubmitToolOutputsToRunRequest(threadId, runId, content, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - finally - { - if (options?.BufferResponse != false) - { - message.Dispose(); - } - } - } - - /// - /// [Protocol Method] When a run has the `status: "requires_action"` and `required_action.type` is - /// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the tool calls once - /// they're all completed. All outputs must be submitted in a single request. - /// - /// The ID of the [thread](/docs/api-reference/threads) to which this run belongs. - /// The ID of the run that requires the tool output submission. - /// The content to send as the body of the request. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// , or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult SubmitToolOutputsToRun(string threadId, string runId, BinaryContent content, RequestOptions options = null) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - Argument.AssertNotNull(content, nameof(content)); - - PipelineMessage message = null; - try - { - message = CreateSubmitToolOutputsToRunRequest(threadId, runId, content, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - finally - { - if (options?.BufferResponse != false) - { - message.Dispose(); - } - } - } - - /// - /// [Protocol Method] Retrieves a run step. - /// - /// The ID of the thread to which the run and run step belongs. - /// The ID of the run to which the run step belongs. - /// The ID of the run step to retrieve. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// , or is null. - /// , or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task GetRunStepAsync(string threadId, string runId, string stepId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - Argument.AssertNotNullOrEmpty(stepId, nameof(stepId)); - - using PipelineMessage message = CreateGetRunStepRequest(threadId, runId, stepId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] Retrieves a run step. - /// - /// The ID of the thread to which the run and run step belongs. - /// The ID of the run to which the run step belongs. - /// The ID of the run step to retrieve. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// , or is null. - /// , or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult GetRunStep(string threadId, string runId, string stepId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); - Argument.AssertNotNullOrEmpty(runId, nameof(runId)); - Argument.AssertNotNullOrEmpty(stepId, nameof(stepId)); - - using PipelineMessage message = CreateGetRunStepRequest(threadId, runId, stepId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } } diff --git a/src/Custom/Assistants/Streaming/RequiredActionUpdate.cs b/src/Custom/Assistants/Streaming/RequiredActionUpdate.cs index 2c189923..c301be98 100644 --- a/src/Custom/Assistants/Streaming/RequiredActionUpdate.cs +++ b/src/Custom/Assistants/Streaming/RequiredActionUpdate.cs @@ -13,23 +13,12 @@ namespace OpenAI.Assistants; /// public class RequiredActionUpdate : RunUpdate { - /// - public string FunctionName => AsFunctionCall?.FunctionName; + public IReadOnlyList RequiredActions { get; } - /// - public string FunctionArguments => AsFunctionCall?.FunctionArguments; - - /// - public string ToolCallId => AsFunctionCall?.Id; - - private InternalRequiredFunctionToolCall AsFunctionCall => _requiredAction as InternalRequiredFunctionToolCall; - - private readonly RequiredAction _requiredAction; - - internal RequiredActionUpdate(ThreadRun run, RequiredAction action) + internal RequiredActionUpdate(ThreadRun run, IReadOnlyList actions) : base(run, StreamingUpdateReason.RunRequiresAction) { - _requiredAction = action; + RequiredActions = actions; } /// @@ -42,11 +31,6 @@ internal RequiredActionUpdate(ThreadRun run, RequiredAction action) internal static IEnumerable DeserializeRequiredActionUpdates(JsonElement element) { ThreadRun run = ThreadRun.DeserializeThreadRun(element); - List updates = []; - foreach (RequiredAction action in run.RequiredActions ?? []) - { - updates.Add(new(run, action)); - } - return updates; + return [new(run, run.RequiredActions)]; } } \ No newline at end of file diff --git a/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs b/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs index d0099d99..5a098669 100644 --- a/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs +++ b/src/Custom/Assistants/Streaming/StreamingUpdateCollection.cs @@ -62,7 +62,7 @@ public StreamingUpdateEnumerator(Func getResult, StreamingUpdate IEnumerator.Current => _current!; - object IEnumerator.Current => throw new NotImplementedException(); + object IEnumerator.Current => _current!; public bool MoveNext() { diff --git a/src/Custom/Batch/BatchClient.Protocol.cs b/src/Custom/Batch/BatchClient.Protocol.cs index 11cc1097..321502ee 100644 --- a/src/Custom/Batch/BatchClient.Protocol.cs +++ b/src/Custom/Batch/BatchClient.Protocol.cs @@ -1,42 +1,58 @@ using System; using System.ClientModel; using System.ClientModel.Primitives; +using System.Text.Json; using System.Threading.Tasks; namespace OpenAI.Batch; public partial class BatchClient { - /// - /// [Protocol Method] Creates and executes a batch from an uploaded file of requests - /// - /// The content to send as the body of the request. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task CreateBatchAsync(BinaryContent content, RequestOptions options = null) + public virtual async Task CreateBatchAsync( + ReturnWhen returnWhen, + BinaryContent content, RequestOptions options = null) { Argument.AssertNotNull(content, nameof(content)); using PipelineMessage message = CreateCreateBatchRequest(content, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + + PipelineResponse response = await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string batchId = doc.RootElement.GetProperty("id"u8).GetString(); + string status = doc.RootElement.GetProperty("status"u8).GetString(); + + BatchOperation operation = new BatchOperation(_pipeline, _endpoint, batchId, status, response); + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + await operation.WaitForCompletionAsync(options?.CancellationToken ?? default).ConfigureAwait(false); + return operation; } - /// - /// [Protocol Method] Creates and executes a batch from an uploaded file of requests - /// - /// The content to send as the body of the request. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult CreateBatch(BinaryContent content, RequestOptions options = null) + public virtual BatchOperation CreateBatch( + ReturnWhen returnWhen, + BinaryContent content, RequestOptions options = null) { Argument.AssertNotNull(content, nameof(content)); using PipelineMessage message = CreateCreateBatchRequest(content, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + PipelineResponse response = _pipeline.ProcessMessage(message, options); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string batchId = doc.RootElement.GetProperty("id"u8).GetString(); + string status = doc.RootElement.GetProperty("status"u8).GetString(); + + BatchOperation operation = new BatchOperation(_pipeline, _endpoint, batchId, status, response); + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + operation.WaitForCompletion(options?.CancellationToken ?? default); + return operation; } /// @@ -100,38 +116,4 @@ public virtual ClientResult GetBatch(string batchId, RequestOptions options) using PipelineMessage message = CreateRetrieveBatchRequest(batchId, options); return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); } - - /// - /// [Protocol Method] Cancels an in-progress batch. - /// - /// The ID of the batch to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task CancelBatchAsync(string batchId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); - - using PipelineMessage message = CreateCancelBatchRequest(batchId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] Cancels an in-progress batch. - /// - /// The ID of the batch to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult CancelBatch(string batchId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); - - using PipelineMessage message = CreateCancelBatchRequest(batchId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } } diff --git a/src/Custom/Batch/BatchClient.cs b/src/Custom/Batch/BatchClient.cs index adce1b0d..a5c75e9a 100644 --- a/src/Custom/Batch/BatchClient.cs +++ b/src/Custom/Batch/BatchClient.cs @@ -2,7 +2,6 @@ using System.ClientModel; using System.ClientModel.Primitives; using System.Collections.Generic; -using System.Threading.Tasks; namespace OpenAI.Batch; diff --git a/src/Custom/FineTuning/FineTuningClient.Protocol.cs b/src/Custom/FineTuning/FineTuningClient.Protocol.cs index 2483478d..5c97e02d 100644 --- a/src/Custom/FineTuning/FineTuningClient.Protocol.cs +++ b/src/Custom/FineTuning/FineTuningClient.Protocol.cs @@ -1,6 +1,7 @@ using System; using System.ClientModel; using System.ClientModel.Primitives; +using System.Text.Json; using System.Threading.Tasks; namespace OpenAI.FineTuning; @@ -34,12 +35,27 @@ public partial class FineTuningClient /// is null. /// Service returned a non-success status code. /// The response returned from the service. - public virtual async Task CreateJobAsync(BinaryContent content, RequestOptions options = null) + public virtual async Task CreateJobAsync( + ReturnWhen returnWhen, + BinaryContent content, RequestOptions options = null) { Argument.AssertNotNull(content, nameof(content)); using PipelineMessage message = CreateCreateFineTuningJobRequest(content, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + PipelineResponse response = await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string jobId = doc.RootElement.GetProperty("id"u8).GetString(); + string status = doc.RootElement.GetProperty("status"u8).GetString(); + + FineTuningOperation operation = new FineTuningOperation(_pipeline, _endpoint, jobId, status, response); + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + await operation.WaitForCompletionAsync(options?.CancellationToken ?? default).ConfigureAwait(false); + return operation; } // CUSTOM: @@ -57,12 +73,27 @@ public virtual async Task CreateJobAsync(BinaryContent content, Re /// is null. /// Service returned a non-success status code. /// The response returned from the service. - public virtual ClientResult CreateJob(BinaryContent content, RequestOptions options = null) + public virtual FineTuningOperation CreateJob( + ReturnWhen returnWhen, + BinaryContent content, RequestOptions options = null) { Argument.AssertNotNull(content, nameof(content)); using PipelineMessage message = CreateCreateFineTuningJobRequest(content, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + PipelineResponse response = _pipeline.ProcessMessage(message, options); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string jobId = doc.RootElement.GetProperty("id"u8).GetString(); + string status = doc.RootElement.GetProperty("status"u8).GetString(); + + FineTuningOperation operation = new FineTuningOperation(_pipeline, _endpoint, jobId, status, response); + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + operation.WaitForCompletion(options?.CancellationToken ?? default); + return operation; } // CUSTOM: @@ -98,170 +129,4 @@ public virtual ClientResult GetJobs(string after, int? limit, RequestOptions opt using PipelineMessage message = CreateGetPaginatedFineTuningJobsRequest(after, limit, options); return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); } - - // CUSTOM: - // - Renamed. - // - Edited doc comment. - /// - /// [Protocol Method] Get info about a fine-tuning job. - /// - /// [Learn more about fine-tuning](/docs/guides/fine-tuning) - /// - /// The ID of the fine-tuning job. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task GetJobAsync(string jobId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - - using PipelineMessage message = CreateRetrieveFineTuningJobRequest(jobId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - // CUSTOM: - // - Renamed. - // - Edited doc comment. - /// - /// [Protocol Method] Get info about a fine-tuning job. - /// - /// [Learn more about fine-tuning](/docs/guides/fine-tuning) - /// - /// The ID of the fine-tuning job. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult GetJob(string jobId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - - using PipelineMessage message = CreateRetrieveFineTuningJobRequest(jobId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - // CUSTOM: - // - Renamed. - // - Edited doc comment. - /// - /// [Protocol Method] Immediately cancel a fine-tune job. - /// - /// The ID of the fine-tuning job to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task CancelJobAsync(string jobId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - - using PipelineMessage message = CreateCancelFineTuningJobRequest(jobId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - // CUSTOM: - // - Renamed. - // - Edited doc comment. - /// - /// [Protocol Method] Immediately cancel a fine-tune job. - /// - /// The ID of the fine-tuning job to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult CancelJob(string jobId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - - using PipelineMessage message = CreateCancelFineTuningJobRequest(jobId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - // CUSTOM: - // - Renamed. - // - Edited doc comment. - /// - /// [Protocol Method] Get status updates for a fine-tuning job. - /// - /// The ID of the fine-tuning job to get events for. - /// Identifier for the last event from the previous pagination request. - /// Number of events to retrieve. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - - using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - // CUSTOM: - // - Renamed. - // - Edited doc comment. - /// - /// [Protocol Method] Get status updates for a fine-tuning job. - /// - /// The ID of the fine-tuning job to get events for. - /// Identifier for the last event from the previous pagination request. - /// Number of events to retrieve. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult GetJobEvents(string jobId, string after, int? limit, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - - using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - /// - /// [Protocol Method] List the checkpoints for a fine-tuning job. - /// - /// The ID of the fine-tuning job to get checkpoints for. - /// Identifier for the last checkpoint ID from the previous pagination request. - /// Number of checkpoints to retrieve. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual async Task GetJobCheckpointsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(fineTuningJobId, nameof(fineTuningJobId)); - - using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(fineTuningJobId, after, limit, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] List the checkpoints for a fine-tuning job. - /// - /// The ID of the fine-tuning job to get checkpoints for. - /// Identifier for the last checkpoint ID from the previous pagination request. - /// Number of checkpoints to retrieve. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - public virtual ClientResult GetJobCheckpoints(string fineTuningJobId, string after, int? limit, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(fineTuningJobId, nameof(fineTuningJobId)); - - using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(fineTuningJobId, after, limit, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } } diff --git a/src/Custom/VectorStores/VectorStoreClient.Convenience.cs b/src/Custom/VectorStores/VectorStoreClient.Convenience.cs index ee48aaca..4148a8c1 100644 --- a/src/Custom/VectorStores/VectorStoreClient.Convenience.cs +++ b/src/Custom/VectorStores/VectorStoreClient.Convenience.cs @@ -170,8 +170,11 @@ public virtual ClientResult RemoveFileFromStore(VectorStore vectorStore, O /// The vector store to associate files with. /// The files to associate with the vector store. /// A instance representing the batch operation. - public virtual Task> CreateBatchFileJobAsync(VectorStore vectorStore, IEnumerable files) - => CreateBatchFileJobAsync(vectorStore?.Id, files?.Select(file => file.Id)); + public virtual Task CreateBatchFileJobAsync( + ReturnWhen returnWhen, + VectorStore vectorStore, + IEnumerable files) + => CreateBatchFileJobAsync(returnWhen, vectorStore?.Id, files?.Select(file => file.Id)); /// /// Begins a batch job to associate multiple jobs with a vector store, beginning the ingestion process. @@ -179,69 +182,9 @@ public virtual Task> CreateBatchFileJobAsy /// The vector store to associate files with. /// The files to associate with the vector store. /// A instance representing the batch operation. - public virtual ClientResult CreateBatchFileJob(VectorStore vectorStore, IEnumerable files) - => CreateBatchFileJob(vectorStore?.Id, files?.Select(file => file.Id)); - - /// - /// Gets an updated instance of an existing , refreshing its status. - /// - /// The job to refresh. - /// The refreshed instance of . - public virtual Task> GetBatchFileJobAsync(VectorStoreBatchFileJob batchJob) - => GetBatchFileJobAsync(batchJob?.VectorStoreId, batchJob?.BatchId); - - /// - /// Gets an updated instance of an existing , refreshing its status. - /// - /// The job to refresh. - /// The refreshed instance of . - public virtual ClientResult GetBatchFileJob(VectorStoreBatchFileJob batchJob) - => GetBatchFileJob(batchJob?.VectorStoreId, batchJob?.BatchId); - - /// - /// Cancels an in-progress . - /// - /// The that should be canceled. - /// An updated instance. - public virtual Task> CancelBatchFileJobAsync(VectorStoreBatchFileJob batchJob) - => CancelBatchFileJobAsync(batchJob?.VectorStoreId, batchJob?.BatchId); - - /// - /// Cancels an in-progress . - /// - /// The that should be canceled. - /// An updated instance. - public virtual ClientResult CancelBatchFileJob(VectorStoreBatchFileJob batchJob) - => CancelBatchFileJob(batchJob?.VectorStoreId, batchJob?.BatchId); - - /// - /// Gets a page collection holding file associations associated with a vector store batch file job, representing the files - /// that were scheduled for ingestion into the vector store. - /// - /// The vector store batch file job to retrieve file associations from. - /// Options describing the collection to return. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual AsyncPageCollection GetFileAssociationsAsync( - VectorStoreBatchFileJob batchJob, - VectorStoreFileAssociationCollectionOptions options = default) - => GetFileAssociationsAsync(batchJob?.VectorStoreId, batchJob?.BatchId, options); - - /// - /// Gets a page collection holding file associations associated with a vector store batch file job, representing the files - /// that were scheduled for ingestion into the vector store. - /// - /// The vector store batch file job to retrieve file associations from. - /// Options describing the collection to return. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual PageCollection GetFileAssociations( - VectorStoreBatchFileJob batchJob, - VectorStoreFileAssociationCollectionOptions options = default) - => GetFileAssociations(batchJob?.VectorStoreId, batchJob?.BatchId, options); - + public virtual VectorStoreFileBatchOperation CreateBatchFileJob( + ReturnWhen returnWhen, + VectorStore vectorStore, + IEnumerable files) + => CreateBatchFileJob(returnWhen, vectorStore?.Id, files?.Select(file => file.Id)); } diff --git a/src/Custom/VectorStores/VectorStoreClient.Protocol.cs b/src/Custom/VectorStores/VectorStoreClient.Protocol.cs index ca3fa76b..4adcd361 100644 --- a/src/Custom/VectorStores/VectorStoreClient.Protocol.cs +++ b/src/Custom/VectorStores/VectorStoreClient.Protocol.cs @@ -3,6 +3,7 @@ using System.ClientModel.Primitives; using System.Collections.Generic; using System.ComponentModel; +using System.Text.Json; using System.Threading.Tasks; namespace OpenAI.VectorStores; @@ -433,13 +434,26 @@ public virtual ClientResult RemoveFileFromStore(string vectorStoreId, string fil /// Service returned a non-success status code. /// The response returned from the service. [EditorBrowsable(EditorBrowsableState.Never)] - public virtual async Task CreateBatchFileJobAsync(string vectorStoreId, BinaryContent content, RequestOptions options = null) + public virtual async Task CreateBatchFileJobAsync(ReturnWhen returnWhen, string vectorStoreId, BinaryContent content, RequestOptions options = null) { Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); Argument.AssertNotNull(content, nameof(content)); using PipelineMessage message = CreateCreateVectorStoreFileBatchRequest(vectorStoreId, content, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + PipelineResponse response = await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string batchId = doc.RootElement.GetProperty("id"u8).GetString(); + string status = doc.RootElement.GetProperty("status"u8).GetString(); + + VectorStoreFileBatchOperation operation = new VectorStoreFileBatchOperation(_pipeline, _endpoint, vectorStoreId, batchId, status, response); + if (returnWhen == ReturnWhen.Started) + { + return operation; + } + + await operation.WaitForCompletionAsync(options?.CancellationToken ?? default).ConfigureAwait(false); + return operation; } /// @@ -453,170 +467,27 @@ public virtual async Task CreateBatchFileJobAsync(string vectorSto /// Service returned a non-success status code. /// The response returned from the service. [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult CreateBatchFileJob(string vectorStoreId, BinaryContent content, RequestOptions options = null) + public virtual VectorStoreFileBatchOperation CreateBatchFileJob( + ReturnWhen returnWhen, + string vectorStoreId, BinaryContent content, RequestOptions options = null) { Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); Argument.AssertNotNull(content, nameof(content)); using PipelineMessage message = CreateCreateVectorStoreFileBatchRequest(vectorStoreId, content, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - /// - /// [Protocol Method] Retrieves a vector store file batch. - /// - /// The ID of the vector store that the file batch belongs to. - /// The ID of the file batch being retrieved. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual async Task GetBatchFileJobAsync(string vectorStoreId, string batchId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); - - using PipelineMessage message = CreateGetVectorStoreFileBatchRequest(vectorStoreId, batchId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] Retrieves a vector store file batch. - /// - /// The ID of the vector store that the file batch belongs to. - /// The ID of the file batch being retrieved. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult GetBatchFileJob(string vectorStoreId, string batchId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); - - using PipelineMessage message = CreateGetVectorStoreFileBatchRequest(vectorStoreId, batchId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - /// - /// [Protocol Method] Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. - /// - /// The ID of the vector store that the file batch belongs to. - /// The ID of the file batch to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual async Task CancelBatchFileJobAsync(string vectorStoreId, string batchId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); - - using PipelineMessage message = CreateCancelVectorStoreFileBatchRequest(vectorStoreId, batchId, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); - } - - /// - /// [Protocol Method] Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. - /// - /// The ID of the vector store that the file batch belongs to. - /// The ID of the file batch to cancel. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// The response returned from the service. - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual ClientResult CancelBatchFileJob(string vectorStoreId, string batchId, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); - - using PipelineMessage message = CreateCancelVectorStoreFileBatchRequest(vectorStoreId, batchId, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); - } - - /// - /// [Protocol Method] Returns a paginated collection of vector store files in a batch. - /// - /// The ID of the vector store that the file batch belongs to. - /// The ID of the file batch that the files belong to. - /// - /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the - /// default is 20. - /// - /// - /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` - /// for descending order. Allowed values: "asc" | "desc" - /// - /// - /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include after=obj_foo in order to fetch the next page of the list. - /// - /// - /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. - /// - /// Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// A collection of service responses, each holding a page of values. - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual IAsyncEnumerable GetFileAssociationsAsync(string vectorStoreId, string batchId, int? limit, string order, string after, string before, string filter, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + PipelineResponse response = _pipeline.ProcessMessage(message, options); - VectorStoreFileBatchesPageEnumerator enumerator = new VectorStoreFileBatchesPageEnumerator(_pipeline, _endpoint, vectorStoreId, batchId, limit, order, after, before, filter, options); - return PageCollectionHelpers.CreateAsync(enumerator); - } + using JsonDocument doc = JsonDocument.Parse(response.Content); + string batchId = doc.RootElement.GetProperty("id"u8).GetString(); + string status = doc.RootElement.GetProperty("status"u8).GetString(); - /// - /// [Protocol Method] Returns a paginated collection of vector store files in a batch. - /// - /// The ID of the vector store that the file batch belongs to. - /// The ID of the file batch that the files belong to. - /// - /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the - /// default is 20. - /// - /// - /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` - /// for descending order. Allowed values: "asc" | "desc" - /// - /// - /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include after=obj_foo in order to fetch the next page of the list. - /// - /// - /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. - /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your - /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. - /// - /// Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`. - /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// or is null. - /// or is an empty string, and was expected to be non-empty. - /// Service returned a non-success status code. - /// A collection of service responses, each holding a page of values. - [EditorBrowsable(EditorBrowsableState.Never)] - public virtual IEnumerable GetFileAssociations(string vectorStoreId, string batchId, int? limit, string order, string after, string before, string filter, RequestOptions options) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + VectorStoreFileBatchOperation operation = new VectorStoreFileBatchOperation(_pipeline, _endpoint, vectorStoreId, batchId, status, response); + if (returnWhen == ReturnWhen.Started) + { + return operation; + } - VectorStoreFileBatchesPageEnumerator enumerator = new VectorStoreFileBatchesPageEnumerator(_pipeline, _endpoint, vectorStoreId, batchId, limit, order, after, before, filter, options); - return PageCollectionHelpers.Create(enumerator); + operation.WaitForCompletion(options?.CancellationToken ?? default); + return operation; } } diff --git a/src/Custom/VectorStores/VectorStoreClient.cs b/src/Custom/VectorStores/VectorStoreClient.cs index 161f4f6a..c87d3cbb 100644 --- a/src/Custom/VectorStores/VectorStoreClient.cs +++ b/src/Custom/VectorStores/VectorStoreClient.cs @@ -15,10 +15,6 @@ namespace OpenAI.VectorStores; [CodeGenClient("VectorStores")] [CodeGenSuppress("CreateVectorStoreAsync", typeof(VectorStoreCreationOptions))] [CodeGenSuppress("CreateVectorStore", typeof(VectorStoreCreationOptions))] -[CodeGenSuppress("GetVectorStoreAsync", typeof(string))] -[CodeGenSuppress("GetVectorStore", typeof(string))] -[CodeGenSuppress("ModifyVectorStoreAsync", typeof(string), typeof(VectorStoreModificationOptions))] -[CodeGenSuppress("ModifyVectorStore", typeof(string), typeof(VectorStoreModificationOptions))] [CodeGenSuppress("DeleteVectorStoreAsync", typeof(string))] [CodeGenSuppress("DeleteVectorStore", typeof(string))] [CodeGenSuppress("GetVectorStoresAsync", typeof(int?), typeof(ListOrder?), typeof(string), typeof(string))] @@ -544,295 +540,70 @@ public virtual ClientResult RemoveFileFromStore(string vectorStoreId, stri /// The IDs of the files to associate with the vector store. /// A token that can be used to cancel this method call. /// A instance representing the batch operation. - public virtual async Task> CreateBatchFileJobAsync(string vectorStoreId, IEnumerable fileIds, CancellationToken cancellationToken = default) + public virtual async Task CreateBatchFileJobAsync( + ReturnWhen returnWhen, + string vectorStoreId, + IEnumerable fileIds, + CancellationToken cancellationToken = default) { Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); Argument.AssertNotNullOrEmpty(fileIds, nameof(fileIds)); BinaryContent content = new InternalCreateVectorStoreFileBatchRequest(fileIds).ToBinaryContent(); - ClientResult result = await CreateBatchFileJobAsync(vectorStoreId, content, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - PipelineResponse response = result?.GetRawResponse(); - VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - return ClientResult.FromValue(value, response); - } - - /// - /// Begins a batch job to associate multiple jobs with a vector store, beginning the ingestion process. - /// - /// The ID of the vector store to associate files with. - /// The IDs of the files to associate with the vector store. - /// A token that can be used to cancel this method call. - /// A instance representing the batch operation. - public virtual ClientResult CreateBatchFileJob(string vectorStoreId, IEnumerable fileIds, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(fileIds, nameof(fileIds)); + RequestOptions options = cancellationToken.ToRequestOptions(); - BinaryContent content = new InternalCreateVectorStoreFileBatchRequest(fileIds).ToBinaryContent(); - ClientResult result = CreateBatchFileJob(vectorStoreId, content, cancellationToken.ToRequestOptions()); - PipelineResponse response = result?.GetRawResponse(); + using PipelineMessage message = CreateCreateVectorStoreFileBatchRequest(vectorStoreId, content, options); + PipelineResponse response = await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false); VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - return ClientResult.FromValue(value, response); - } - /// - /// Gets an existing vector store batch file ingestion job from a known vector store ID and job ID. - /// - /// The ID of the vector store into which the batch of files was started. - /// The ID of the batch operation adding files to the vector store. - /// A token that can be used to cancel this method call. - /// A instance representing the ingestion operation. - public virtual async Task> GetBatchFileJobAsync(string vectorStoreId, string batchJobId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchJobId, nameof(batchJobId)); - - ClientResult result = await GetBatchFileJobAsync(vectorStoreId, batchJobId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - PipelineResponse response = result?.GetRawResponse(); - VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - return ClientResult.FromValue(value, response); - } - - /// - /// Gets an existing vector store batch file ingestion job from a known vector store ID and job ID. - /// - /// The ID of the vector store into which the batch of files was started. - /// The ID of the batch operation adding files to the vector store. - /// A token that can be used to cancel this method call. - /// A instance representing the ingestion operation. - public virtual ClientResult GetBatchFileJob(string vectorStoreId, string batchJobId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchJobId, nameof(batchJobId)); + VectorStoreFileBatchOperation operation = new VectorStoreFileBatchOperation( + _pipeline, + _endpoint, + ClientResult.FromValue(value, response)); - ClientResult result = GetBatchFileJob(vectorStoreId, batchJobId, cancellationToken.ToRequestOptions()); - PipelineResponse response = result?.GetRawResponse(); - VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - return ClientResult.FromValue(value, response); - } - - /// - /// Cancels an in-progress . - /// - /// - /// The ID of the that is the ingestion target of the batch job being cancelled. - /// - /// - /// The ID of the that should be canceled. - /// - /// A token that can be used to cancel this method call. - /// An updated instance. - public virtual async Task> CancelBatchFileJobAsync(string vectorStoreId, string batchJobId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchJobId, nameof(batchJobId)); - - ClientResult result = await CancelBatchFileJobAsync(vectorStoreId, batchJobId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); - PipelineResponse response = result?.GetRawResponse(); - VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - return ClientResult.FromValue(value, response); - } - - /// - /// Cancels an in-progress . - /// - /// - /// The ID of the that is the ingestion target of the batch job being cancelled. - /// - /// - /// The ID of the that should be canceled. - /// - /// A token that can be used to cancel this method call. - /// An updated instance. - public virtual ClientResult CancelBatchFileJob(string vectorStoreId, string batchJobId, CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchJobId, nameof(batchJobId)); - - ClientResult result = CancelBatchFileJob(vectorStoreId, batchJobId, cancellationToken.ToRequestOptions()); - PipelineResponse response = result?.GetRawResponse(); - VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - return ClientResult.FromValue(value, response); - } - - /// - /// Gets a page collection of file associations associated with a vector store batch file job, representing the files - /// that were scheduled for ingestion into the vector store. - /// - /// - /// The ID of the vector store into which the file batch was scheduled for ingestion. - /// - /// - /// The ID of the batch file job that was previously scheduled. - /// - /// Options describing the collection to return. - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual AsyncPageCollection GetFileAssociationsAsync( - string vectorStoreId, - string batchJobId, - VectorStoreFileAssociationCollectionOptions options = default, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchJobId, nameof(batchJobId)); - - VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, - vectorStoreId, - batchJobId, - options?.PageSize, - options?.Order?.ToString(), - options?.AfterId, - options?.BeforeId, - options?.Filter?.ToString(), - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.CreateAsync(enumerator); - } - - /// - /// Rehydrates a page collection of file associations from a page token. - /// - /// - /// The ID of the vector store into which the file batch was scheduled for ingestion. - /// - /// - /// The ID of the batch file job that was previously scheduled. - /// - /// Page token corresponding to the first page of the collection to rehydrate. - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual AsyncPageCollection GetFileAssociationsAsync( - string vectorStoreId, - string batchJobId, - ContinuationToken firstPageToken, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); - - VectorStoreFileBatchesPageToken pageToken = VectorStoreFileBatchesPageToken.FromToken(firstPageToken); - - if (vectorStoreId != pageToken.VectorStoreId) - { - throw new ArgumentException( - "Invalid page token. 'vectorStoreId' value does not match page token value.", - nameof(vectorStoreId)); - } - - if (batchJobId != pageToken.BatchId) + if (returnWhen == ReturnWhen.Started) { - throw new ArgumentException( - "Invalid page token. 'batchJobId' value does not match page token value.", - nameof(vectorStoreId)); + return operation; } - VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, - pageToken.VectorStoreId, - pageToken.BatchId, - pageToken.Limit, - pageToken.Order, - pageToken.After, - pageToken.Before, - pageToken.Filter, - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.CreateAsync(enumerator); + await operation.WaitForCompletionAsync().ConfigureAwait(false); + return operation; } /// - /// Gets a page collection of file associations associated with a vector store batch file job, representing the files - /// that were scheduled for ingestion into the vector store. + /// Begins a batch job to associate multiple jobs with a vector store, beginning the ingestion process. /// - /// - /// The ID of the vector store into which the file batch was scheduled for ingestion. - /// - /// - /// The ID of the batch file job that was previously scheduled. - /// - /// Options describing the collection to return. + /// The ID of the vector store to associate files with. + /// The IDs of the files to associate with the vector store. /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual PageCollection GetFileAssociations( + /// A instance representing the batch operation. + public virtual VectorStoreFileBatchOperation CreateBatchFileJob( + ReturnWhen returnWhen, string vectorStoreId, - string batchJobId, - VectorStoreFileAssociationCollectionOptions options = default, + IEnumerable fileIds, CancellationToken cancellationToken = default) { Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); - Argument.AssertNotNullOrEmpty(batchJobId, nameof(batchJobId)); - - VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, - vectorStoreId, - batchJobId, - options?.PageSize, - options?.Order?.ToString(), - options?.AfterId, - options?.BeforeId, - options?.Filter?.ToString(), - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.Create(enumerator); - } + Argument.AssertNotNullOrEmpty(fileIds, nameof(fileIds)); - /// - /// Rehydrates a page collection of file associations from a page token. - /// that were scheduled for ingestion into the vector store. - /// - /// - /// The ID of the vector store into which the file batch was scheduled for ingestion. - /// - /// - /// The ID of the batch file job that was previously scheduled. - /// - /// Page token corresponding to the first page of the collection to rehydrate. - /// A token that can be used to cancel this method call. - /// holds pages of values. To obtain a collection of values, call - /// . To obtain the current - /// page of values, call . - /// A collection of pages of . - public virtual PageCollection GetFileAssociations( - string vectorStoreId, - string batchJobId, - ContinuationToken firstPageToken, - CancellationToken cancellationToken = default) - { - Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); + BinaryContent content = new InternalCreateVectorStoreFileBatchRequest(fileIds).ToBinaryContent(); + RequestOptions options = cancellationToken.ToRequestOptions(); - VectorStoreFileBatchesPageToken pageToken = VectorStoreFileBatchesPageToken.FromToken(firstPageToken); + using PipelineMessage message = CreateCreateVectorStoreFileBatchRequest(vectorStoreId, content, options); + PipelineResponse response = _pipeline.ProcessMessage(message, options); + VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); - if (vectorStoreId != pageToken.VectorStoreId) - { - throw new ArgumentException( - "Invalid page token. 'vectorStoreId' value does not match page token value.", - nameof(vectorStoreId)); - } + VectorStoreFileBatchOperation operation = new VectorStoreFileBatchOperation( + _pipeline, + _endpoint, + ClientResult.FromValue(value, response)); - if (batchJobId != pageToken.BatchId) + if (returnWhen == ReturnWhen.Started) { - throw new ArgumentException( - "Invalid page token. 'batchJobId' value does not match page token value.", - nameof(vectorStoreId)); + return operation; } - VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, - pageToken.VectorStoreId, - pageToken.BatchId, - pageToken.Limit, - pageToken.Order, - pageToken.After, - pageToken.Before, - pageToken.Filter, - cancellationToken.ToRequestOptions()); - - return PageCollectionHelpers.Create(enumerator); + operation.WaitForCompletion(); + return operation; } } diff --git a/src/Generated/BatchClient.cs b/src/Generated/BatchClient.cs index 3ccd2f90..ff92edef 100644 --- a/src/Generated/BatchClient.cs +++ b/src/Generated/BatchClient.cs @@ -81,23 +81,6 @@ internal PipelineMessage CreateRetrieveBatchRequest(string batchId, RequestOptio return message; } - internal PipelineMessage CreateCancelBatchRequest(string batchId, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "POST"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/batches/", false); - uri.AppendPath(batchId, true); - uri.AppendPath("/cancel", false); - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - private static PipelineMessageClassifier _pipelineMessageClassifier200; private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); } diff --git a/src/Generated/FineTuningClient.cs b/src/Generated/FineTuningClient.cs index 7424e033..1427659e 100644 --- a/src/Generated/FineTuningClient.cs +++ b/src/Generated/FineTuningClient.cs @@ -71,89 +71,6 @@ internal PipelineMessage CreateGetPaginatedFineTuningJobsRequest(string after, i return message; } - internal PipelineMessage CreateRetrieveFineTuningJobRequest(string fineTuningJobId, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "GET"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/fine_tuning/jobs/", false); - uri.AppendPath(fineTuningJobId, true); - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - - internal PipelineMessage CreateCancelFineTuningJobRequest(string fineTuningJobId, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "POST"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/fine_tuning/jobs/", false); - uri.AppendPath(fineTuningJobId, true); - uri.AppendPath("/cancel", false); - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - - internal PipelineMessage CreateGetFineTuningJobCheckpointsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "GET"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/fine_tuning/jobs/", false); - uri.AppendPath(fineTuningJobId, true); - uri.AppendPath("/checkpoints", false); - if (after != null) - { - uri.AppendQuery("after", after, true); - } - if (limit != null) - { - uri.AppendQuery("limit", limit.Value, true); - } - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - - internal PipelineMessage CreateGetFineTuningEventsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "GET"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/fine_tuning/jobs/", false); - uri.AppendPath(fineTuningJobId, true); - uri.AppendPath("/events", false); - if (after != null) - { - uri.AppendQuery("after", after, true); - } - if (limit != null) - { - uri.AppendQuery("limit", limit.Value, true); - } - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - private static PipelineMessageClassifier _pipelineMessageClassifier200; private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); } diff --git a/src/Generated/VectorStoreClient.cs b/src/Generated/VectorStoreClient.cs index 6f2be46a..61cb80aa 100644 --- a/src/Generated/VectorStoreClient.cs +++ b/src/Generated/VectorStoreClient.cs @@ -240,82 +240,6 @@ internal PipelineMessage CreateCreateVectorStoreFileBatchRequest(string vectorSt return message; } - internal PipelineMessage CreateGetVectorStoreFileBatchRequest(string vectorStoreId, string batchId, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "GET"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/vector_stores/", false); - uri.AppendPath(vectorStoreId, true); - uri.AppendPath("/file_batches/", false); - uri.AppendPath(batchId, true); - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - - internal PipelineMessage CreateCancelVectorStoreFileBatchRequest(string vectorStoreId, string batchId, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "POST"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/vector_stores/", false); - uri.AppendPath(vectorStoreId, true); - uri.AppendPath("/file_batches/", false); - uri.AppendPath(batchId, true); - uri.AppendPath("/cancel", false); - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - - internal PipelineMessage CreateGetFilesInVectorStoreBatchesRequest(string vectorStoreId, string batchId, int? limit, string order, string after, string before, string filter, RequestOptions options) - { - var message = _pipeline.CreateMessage(); - message.ResponseClassifier = PipelineMessageClassifier200; - var request = message.Request; - request.Method = "GET"; - var uri = new ClientUriBuilder(); - uri.Reset(_endpoint); - uri.AppendPath("/vector_stores/", false); - uri.AppendPath(vectorStoreId, true); - uri.AppendPath("/file_batches/", false); - uri.AppendPath(batchId, true); - uri.AppendPath("/files", false); - if (limit != null) - { - uri.AppendQuery("limit", limit.Value, true); - } - if (order != null) - { - uri.AppendQuery("order", order, true); - } - if (after != null) - { - uri.AppendQuery("after", after, true); - } - if (before != null) - { - uri.AppendQuery("before", before, true); - } - if (filter != null) - { - uri.AppendQuery("filter", filter, true); - } - request.Uri = uri.ToUri(); - request.Headers.Set("Accept", "application/json"); - message.Apply(options); - return message; - } - private static PipelineMessageClassifier _pipelineMessageClassifier200; private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); } diff --git a/src/OpenAI.csproj b/src/OpenAI.csproj index eed204d5..b244f908 100644 --- a/src/OpenAI.csproj +++ b/src/OpenAI.csproj @@ -72,6 +72,9 @@ - + + + + diff --git a/src/To.Be.Generated/BatchOperation.Protocol.cs b/src/To.Be.Generated/BatchOperation.Protocol.cs new file mode 100644 index 00000000..bc0398d1 --- /dev/null +++ b/src/To.Be.Generated/BatchOperation.Protocol.cs @@ -0,0 +1,153 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Batch; + +// Protocol version +public partial class BatchOperation : OperationResult +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly string _batchId; + + private PollingInterval _pollingInterval; + + // For use with protocol methods where the response has been obtained prior + // to creation of the LRO instance. + internal BatchOperation( + ClientPipeline pipeline, + Uri endpoint, + string batchId, + string status, + PipelineResponse response) + : base(response) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _batchId = batchId; + + IsCompleted = GetIsCompleted(status); + + _pollingInterval = new(); + + RehydrationToken = new BatchOperationToken(batchId); + } + + public override ContinuationToken? RehydrationToken { get; protected set; } + + public override bool IsCompleted { get; protected set; } + + // These are replaced when LRO is evolved to have conveniences + public override async Task WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + IAsyncEnumerator enumerator = + new BatchOperationUpdateEnumerator(_pipeline, _endpoint, _batchId, cancellationToken); + + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + ApplyUpdate(enumerator.Current); + + await _pollingInterval.WaitAsync(cancellationToken); + } + } + + public override void WaitForCompletion(CancellationToken cancellationToken = default) + { + IEnumerator enumerator = new BatchOperationUpdateEnumerator( + _pipeline, _endpoint, _batchId, cancellationToken); + + while (enumerator.MoveNext()) + { + ApplyUpdate(enumerator.Current); + + cancellationToken.ThrowIfCancellationRequested(); + + _pollingInterval.Wait(); + } + } + + private void ApplyUpdate(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + IsCompleted = GetIsCompleted(status); + SetRawResponse(response); + } + + private static bool GetIsCompleted(string? status) + { + return status == "completed" || + status == "cancelled" || + status == "expired" || + status == "failed"; + } + + // Generated protocol methods + + /// + /// [Protocol Method] Cancels an in-progress batch. + /// + /// The ID of the batch to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task CancelBatchAsync(string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateCancelBatchRequest(batchId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Cancels an in-progress batch. + /// + /// The ID of the batch to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult CancelBatch(string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateCancelBatchRequest(batchId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateCancelBatchRequest(string batchId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/batches/", false); + uri.AppendPath(batchId, true); + uri.AppendPath("/cancel", false); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} \ No newline at end of file diff --git a/src/To.Be.Generated/BatchOperationToken.cs b/src/To.Be.Generated/BatchOperationToken.cs new file mode 100644 index 00000000..41b3b9a1 --- /dev/null +++ b/src/To.Be.Generated/BatchOperationToken.cs @@ -0,0 +1,90 @@ +using System; +using System.ClientModel; +using System.Diagnostics; +using System.IO; +using System.Text.Json; + +#nullable enable + +namespace OpenAI.Batch; + +internal class BatchOperationToken : ContinuationToken +{ + public BatchOperationToken(string batchId) + { + BatchId = batchId; + } + + public string BatchId { get; } + + public override BinaryData ToBytes() + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + writer.WriteStartObject(); + + writer.WriteString("batchId", BatchId); + + writer.WriteEndObject(); + + writer.Flush(); + stream.Position = 0; + + return BinaryData.FromStream(stream); + } + + public static BatchOperationToken FromToken(ContinuationToken continuationToken) + { + if (continuationToken is BatchOperationToken token) + { + return token; + } + + BinaryData data = continuationToken.ToBytes(); + + if (data.ToMemory().Length == 0) + { + throw new ArgumentException("Failed to create BatchOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + Utf8JsonReader reader = new(data); + + string batchId = null!; + + reader.Read(); + + Debug.Assert(reader.TokenType == JsonTokenType.StartObject); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + Debug.Assert(reader.TokenType == JsonTokenType.PropertyName); + + string propertyName = reader.GetString()!; + + switch (propertyName) + { + case "batchId": + reader.Read(); + Debug.Assert(reader.TokenType == JsonTokenType.String); + batchId = reader.GetString()!; + break; + + default: + throw new JsonException($"Unrecognized property '{propertyName}'."); + } + } + + if (batchId is null) + { + throw new ArgumentException("Failed to create BatchOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + return new(batchId); + } +} + diff --git a/src/To.Be.Generated/BatchOperationUpdateEnumerator.Protocol.cs b/src/To.Be.Generated/BatchOperationUpdateEnumerator.Protocol.cs new file mode 100644 index 00000000..ae568011 --- /dev/null +++ b/src/To.Be.Generated/BatchOperationUpdateEnumerator.Protocol.cs @@ -0,0 +1,165 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Batch; + +internal partial class BatchOperationUpdateEnumerator : + IAsyncEnumerator, + IEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly CancellationToken _cancellationToken; + + private readonly string _batchId; + + // TODO: does this one need to be nullable? + private ClientResult? _current; + private bool _hasNext = true; + + public BatchOperationUpdateEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string batchId, + CancellationToken cancellationToken) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _batchId = batchId; + + _cancellationToken = cancellationToken; + } + + public ClientResult Current => _current!; + + #region IEnumerator methods + + object IEnumerator.Current => _current!; + + bool IEnumerator.MoveNext() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = GetBatch(_batchId, _cancellationToken.ToRequestOptions()); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + void IEnumerator.Reset() => _current = null; + + void IDisposable.Dispose() { } + + #endregion + + #region IAsyncEnumerator methods + + ClientResult IAsyncEnumerator.Current => _current!; + + public async ValueTask MoveNextAsync() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = await GetBatchAsync(_batchId, _cancellationToken.ToRequestOptions()).ConfigureAwait(false); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + // TODO: handle Dispose and DisposeAsync using proper patterns? + ValueTask IAsyncDisposable.DisposeAsync() => default; + + #endregion + + private bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + // TODO: don't parse JsonDocument twice if possible + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + bool isComplete = status == "completed" || + status == "cancelled" || + status == "expired" || + status == "failed"; + + return !isComplete; + } + + // Generated methods + + /// + /// [Protocol Method] Retrieves a batch. + /// + /// The ID of the batch to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task GetBatchAsync(string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateRetrieveBatchRequest(batchId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Retrieves a batch. + /// + /// The ID of the batch to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult GetBatch(string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateRetrieveBatchRequest(batchId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateRetrieveBatchRequest(string batchId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/batches/", false); + uri.AppendPath(batchId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/To.Be.Generated/FineTuningOperation.Protocol.cs b/src/To.Be.Generated/FineTuningOperation.Protocol.cs new file mode 100644 index 00000000..c4a6dd84 --- /dev/null +++ b/src/To.Be.Generated/FineTuningOperation.Protocol.cs @@ -0,0 +1,344 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.FineTuning; + +public partial class FineTuningOperation : OperationResult +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly string _jobId; + + private PollingInterval _pollingInterval; + + internal FineTuningOperation( + ClientPipeline pipeline, + Uri endpoint, + string jobId, + string status, + PipelineResponse response) : base(response) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _jobId = jobId; + IsCompleted = GetIsCompleted(status); + + _pollingInterval = new(); + + RehydrationToken = new FineTuningOperationToken(jobId); + } + + public override ContinuationToken? RehydrationToken { get; protected set; } + + public override bool IsCompleted { get; protected set; } + + public override async Task WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + IAsyncEnumerator enumerator = new FineTuningOperationUpdateEnumerator( + _pipeline, _endpoint, _jobId, cancellationToken); + + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + ApplyUpdate(enumerator.Current); + + await _pollingInterval.WaitAsync(cancellationToken); + } + } + + public override void WaitForCompletion(CancellationToken cancellationToken = default) + { + IEnumerator enumerator = new FineTuningOperationUpdateEnumerator( + _pipeline, _endpoint, _jobId, cancellationToken); + + while (enumerator.MoveNext()) + { + ApplyUpdate(enumerator.Current); + + cancellationToken.ThrowIfCancellationRequested(); + + _pollingInterval.Wait(); + } + } + + private void ApplyUpdate(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + IsCompleted = GetIsCompleted(status); + SetRawResponse(response); + } + + private static bool GetIsCompleted(string? status) + { + return status == "succeeded" || + status == "failed" || + status == "cancelled"; + } + + // Generated protocol methods + + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Get info about a fine-tuning job. + /// + /// [Learn more about fine-tuning](/docs/guides/fine-tuning) + /// + /// The ID of the fine-tuning job. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task GetJobAsync(string jobId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateRetrieveFineTuningJobRequest(jobId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Get info about a fine-tuning job. + /// + /// [Learn more about fine-tuning](/docs/guides/fine-tuning) + /// + /// The ID of the fine-tuning job. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult GetJob(string jobId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateRetrieveFineTuningJobRequest(jobId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Immediately cancel a fine-tune job. + /// + /// The ID of the fine-tuning job to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task CancelJobAsync(string jobId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateCancelFineTuningJobRequest(jobId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Immediately cancel a fine-tune job. + /// + /// The ID of the fine-tuning job to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult CancelJob(string jobId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateCancelFineTuningJobRequest(jobId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Get status updates for a fine-tuning job. + /// + /// The ID of the fine-tuning job to get events for. + /// Identifier for the last event from the previous pagination request. + /// Number of events to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Get status updates for a fine-tuning job. + /// + /// The ID of the fine-tuning job to get events for. + /// Identifier for the last event from the previous pagination request. + /// Number of events to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult GetJobEvents(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + /// + /// [Protocol Method] List the checkpoints for a fine-tuning job. + /// + /// The ID of the fine-tuning job to get checkpoints for. + /// Identifier for the last checkpoint ID from the previous pagination request. + /// Number of checkpoints to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task GetJobCheckpointsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(fineTuningJobId, nameof(fineTuningJobId)); + + using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(fineTuningJobId, after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] List the checkpoints for a fine-tuning job. + /// + /// The ID of the fine-tuning job to get checkpoints for. + /// Identifier for the last checkpoint ID from the previous pagination request. + /// Number of checkpoints to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult GetJobCheckpoints(string fineTuningJobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(fineTuningJobId, nameof(fineTuningJobId)); + + using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(fineTuningJobId, after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateRetrieveFineTuningJobRequest(string fineTuningJobId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateCancelFineTuningJobRequest(string fineTuningJobId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + uri.AppendPath("/cancel", false); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateGetFineTuningJobCheckpointsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + uri.AppendPath("/checkpoints", false); + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateGetFineTuningEventsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + uri.AppendPath("/events", false); + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/To.Be.Generated/FineTuningOperationToken.cs b/src/To.Be.Generated/FineTuningOperationToken.cs new file mode 100644 index 00000000..a2f131d3 --- /dev/null +++ b/src/To.Be.Generated/FineTuningOperationToken.cs @@ -0,0 +1,91 @@ +using System; +using System.ClientModel; +using System.Diagnostics; +using System.IO; +using System.Text.Json; + +#nullable enable + +namespace OpenAI.FineTuning; + +internal class FineTuningOperationToken : ContinuationToken +{ + public FineTuningOperationToken(string jobId) + { + JobId = jobId; + } + + public string JobId { get; } + + public override BinaryData ToBytes() + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + + writer.WriteStartObject(); + + writer.WriteString("jobId", JobId); + + writer.WriteEndObject(); + + writer.Flush(); + stream.Position = 0; + + return BinaryData.FromStream(stream); + } + + public static FineTuningOperationToken FromToken(ContinuationToken continuationToken) + { + if (continuationToken is FineTuningOperationToken token) + { + return token; + } + + BinaryData data = continuationToken.ToBytes(); + + if (data.ToMemory().Length == 0) + { + throw new ArgumentException("Failed to create FineTuningOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + Utf8JsonReader reader = new(data); + + string jobId = null!; + + reader.Read(); + + Debug.Assert(reader.TokenType == JsonTokenType.StartObject); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + Debug.Assert(reader.TokenType == JsonTokenType.PropertyName); + + string propertyName = reader.GetString()!; + + switch (propertyName) + { + case "jobId": + reader.Read(); + Debug.Assert(reader.TokenType == JsonTokenType.String); + jobId = reader.GetString()!; + break; + + default: + throw new JsonException($"Unrecognized property '{propertyName}'."); + } + } + + if (jobId is null) + { + throw new ArgumentException("Failed to create RunOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + return new(jobId); + } +} + diff --git a/src/To.Be.Generated/FineTuningOperationUpdateEnumerator.Protocol.cs b/src/To.Be.Generated/FineTuningOperationUpdateEnumerator.Protocol.cs new file mode 100644 index 00000000..29d75f8e --- /dev/null +++ b/src/To.Be.Generated/FineTuningOperationUpdateEnumerator.Protocol.cs @@ -0,0 +1,170 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.FineTuning; + +internal partial class FineTuningOperationUpdateEnumerator : + IAsyncEnumerator, + IEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly CancellationToken _cancellationToken; + + private readonly string _jobId; + + private ClientResult? _current; + private bool _hasNext = true; + + public FineTuningOperationUpdateEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string jobId, + CancellationToken cancellationToken) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _jobId = jobId; + + _cancellationToken = cancellationToken; + } + + public ClientResult Current => _current!; + + #region IEnumerator methods + + object IEnumerator.Current => _current!; + + bool IEnumerator.MoveNext() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = GetJob(_jobId, _cancellationToken.ToRequestOptions()); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + void IEnumerator.Reset() => _current = null; + + void IDisposable.Dispose() { } + + #endregion + + #region IAsyncEnumerator methods + + ClientResult IAsyncEnumerator.Current => _current!; + + public async ValueTask MoveNextAsync() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = await GetJobAsync(_jobId, _cancellationToken.ToRequestOptions()).ConfigureAwait(false); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + // TODO: handle Dispose and DisposeAsync using proper patterns? + ValueTask IAsyncDisposable.DisposeAsync() => default; + + #endregion + + private bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + // TODO: don't parse JsonDocument twice if possible + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + bool isComplete = status == "succeeded" || + status == "failed" || + status == "cancelled"; + + return !isComplete; + } + + // Generated methods + + /// + /// [Protocol Method] Get info about a fine-tuning job. + /// + /// [Learn more about fine-tuning](/docs/guides/fine-tuning) + /// + /// The ID of the fine-tuning job. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task GetJobAsync(string jobId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateRetrieveFineTuningJobRequest(jobId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + // CUSTOM: + // - Renamed. + // - Edited doc comment. + /// + /// [Protocol Method] Get info about a fine-tuning job. + /// + /// [Learn more about fine-tuning](/docs/guides/fine-tuning) + /// + /// The ID of the fine-tuning job. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// is null. + /// is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult GetJob(string jobId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateRetrieveFineTuningJobRequest(jobId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateRetrieveFineTuningJobRequest(string fineTuningJobId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/To.Be.Generated/Internal/PollingInterval.cs b/src/To.Be.Generated/Internal/PollingInterval.cs new file mode 100644 index 00000000..d159aabd --- /dev/null +++ b/src/To.Be.Generated/Internal/PollingInterval.cs @@ -0,0 +1,29 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI; + +internal class PollingInterval +{ + private const int DefaultWaitMilliseconds = 1000; + + private readonly TimeSpan _interval; + + public PollingInterval(TimeSpan? interval = default) + { + _interval = interval ?? new TimeSpan(DefaultWaitMilliseconds); + } + + public async Task WaitAsync(CancellationToken cancellationToken) + { + await Task.Delay(_interval, cancellationToken); + } + + public void Wait() + { + Thread.Sleep(_interval); + } +} diff --git a/src/To.Be.Generated/RunOperation.Protocol.cs b/src/To.Be.Generated/RunOperation.Protocol.cs new file mode 100644 index 00000000..5d9b6c0e --- /dev/null +++ b/src/To.Be.Generated/RunOperation.Protocol.cs @@ -0,0 +1,639 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.ComponentModel; +using System.Text.Json; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Assistants; + +// Protocol version +public partial class RunOperation : ClientResult +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private string? _threadId; + private string? _runId; + private string? _status; + + private bool _isCompleted; + + private PollingInterval? _pollingInterval; + + // For use with protocol methods where the response has been obtained prior + // to creation of the LRO instance. + internal RunOperation( + ClientPipeline pipeline, + Uri endpoint, + PipelineResponse response) + : base(response) + { + _pipeline = pipeline; + _endpoint = endpoint; + + // Protocol method was called with stream=true option. + bool isStreaming = + response.Headers.TryGetValue("Content-Type", out string? contentType) && + contentType == "text/event-stream; charset=utf-8"; + + if (!isStreaming) + { + _pollingInterval = new(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + + _status = doc.RootElement.GetProperty("status"u8).GetString(); + _threadId = doc.RootElement.GetProperty("thread_id"u8).GetString(); + _runId = doc.RootElement.GetProperty("id"u8).GetString(); + + if (_status is null || _threadId is null || _runId is null) + { + throw new ArgumentException("Invalid 'response' body.", nameof(response)); + } + + IsCompleted = GetIsCompleted(_status!); + } + } + + #region OperationResult methods + + public virtual bool IsCompleted + { + get + { + // We need this check in the protocol/streaming case. + if (IsStreaming) + { + throw new NotSupportedException("Cannot obtain operation status from streaming operation."); + } + + return _isCompleted; + } + + protected set => _isCompleted = value; + } + + public virtual ContinuationToken? RehydrationToken { get; protected set; } + + internal bool IsStreaming => _pollingInterval == null; + + // Note: these work for protocol-only. + // Once convenience overloads available, these get replaced by those implementations. + + //public override Task WaitAsync(CancellationToken cancellationToken = default) + //{ + // if (_isStreaming) + // { + // // We would have to read from the stream to get the run ID to poll for. + // throw new NotSupportedException("Cannot poll for status updates from streaming operation."); + // } + + // // See: https://platform.openai.com/docs/assistants/how-it-works/polling-for-updates + + // IAsyncEnumerator enumerator = GetUpdateResultEnumeratorAsync(); + + // await while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + // { + // ApplyUpdate(enumerator.Current); + + // // Don't keep polling if would do so infinitely. + // if (_status == "requires_action") + // { + // return; + // } + + // cancellationToken.ThrowIfCancellationRequested(); + + // await _pollingInterval.WaitAsync().ConfigureAwait(); + // } + //} + + //public override void Wait(CancellationToken cancellationToken = default) + //{ + // if (_isStreaming) + // { + // // We would have to read from the stream to get the run ID to poll for. + // throw new NotSupportedException("Cannot poll for status updates from streaming operation."); + // } + + // // See: https://platform.openai.com/docs/assistants/how-it-works/polling-for-updates + + // IEnumerator enumerator = GetUpdateResultEnumerator(); + + // while (enumerator.MoveNext()) + // { + // ApplyUpdate(enumerator.Current); + + // // Don't keep polling if would do so infinitely. + // if (_status == "requires_action") + // { + // return; + // } + + // cancellationToken.ThrowIfCancellationRequested(); + + // _pollingInterval.Wait(); + // } + //} + + //private void ApplyUpdate(ClientResult result) + //{ + // PipelineResponse response = result.GetRawResponse(); + + // using JsonDocument doc = JsonDocument.Parse(response.Content); + // _status = doc.RootElement.GetProperty("status"u8).GetString(); + + // IsCompleted = GetIsCompleted(_status!); + // SetRawResponse(response); + //} + + private static bool GetIsCompleted(string status) + { + bool hasCompleted = + status == "expired" || + status == "completed" || + status == "failed" || + status == "incomplete" || + status == "cancelled"; + + return hasCompleted; + } + + #endregion + + #region Generated protocol methods - i.e. TypeSpec "linked operations" + + /// + /// [Protocol Method] Retrieves a run. + /// + /// The ID of the [thread](/docs/api-reference/threads) that was run. + /// The ID of the run to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task GetRunAsync(string threadId, string runId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + using PipelineMessage message = CreateGetRunRequest(threadId, runId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Retrieves a run. + /// + /// The ID of the [thread](/docs/api-reference/threads) that was run. + /// The ID of the run to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult GetRun(string threadId, string runId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + using PipelineMessage message = CreateGetRunRequest(threadId, runId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + /// + /// [Protocol Method] Modifies a run. + /// + /// The ID of the [thread](/docs/api-reference/threads) that was run. + /// The ID of the run to modify. + /// The content to send as the body of the request. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// , or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task ModifyRunAsync(string threadId, string runId, BinaryContent content, RequestOptions? options = null) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + Argument.AssertNotNull(content, nameof(content)); + + using PipelineMessage message = CreateModifyRunRequest(threadId, runId, content, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Modifies a run. + /// + /// The ID of the [thread](/docs/api-reference/threads) that was run. + /// The ID of the run to modify. + /// The content to send as the body of the request. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// , or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult ModifyRun(string threadId, string runId, BinaryContent content, RequestOptions? options = null) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + Argument.AssertNotNull(content, nameof(content)); + + using PipelineMessage message = CreateModifyRunRequest(threadId, runId, content, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + /// + /// [Protocol Method] Cancels a run that is `in_progress`. + /// + /// The ID of the thread to which this run belongs. + /// The ID of the run to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task CancelRunAsync(string threadId, string runId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + using PipelineMessage message = CreateCancelRunRequest(threadId, runId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Cancels a run that is `in_progress`. + /// + /// The ID of the thread to which this run belongs. + /// The ID of the run to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult CancelRun(string threadId, string runId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + using PipelineMessage message = CreateCancelRunRequest(threadId, runId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + /// + /// [Protocol Method] When a run has the `status: "requires_action"` and `required_action.type` is + /// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the tool calls once + /// they're all completed. All outputs must be submitted in a single request. + /// + /// The ID of the [thread](/docs/api-reference/threads) to which this run belongs. + /// The ID of the run that requires the tool output submission. + /// The content to send as the body of the request. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// , or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task SubmitToolOutputsToRunAsync(string threadId, string runId, BinaryContent content, RequestOptions? options = null) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + Argument.AssertNotNull(content, nameof(content)); + + PipelineMessage? message = null; + try + { + message = CreateSubmitToolOutputsToRunRequest(threadId, runId, content, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + finally + { + if (options?.BufferResponse != false) + { + message?.Dispose(); + } + } + } + + /// + /// [Protocol Method] When a run has the `status: "requires_action"` and `required_action.type` is + /// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the tool calls once + /// they're all completed. All outputs must be submitted in a single request. + /// + /// The ID of the [thread](/docs/api-reference/threads) to which this run belongs. + /// The ID of the run that requires the tool output submission. + /// The content to send as the body of the request. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// , or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult SubmitToolOutputsToRun(string threadId, string runId, BinaryContent content, RequestOptions? options = null) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + Argument.AssertNotNull(content, nameof(content)); + + PipelineMessage? message = null; + try + { + message = CreateSubmitToolOutputsToRunRequest(threadId, runId, content, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + finally + { + if (options?.BufferResponse != false) + { + message?.Dispose(); + } + } + } + + /// + /// [Protocol Method] Returns a paginated collection of run steps belonging to a run. + /// + /// The ID of the thread the run and run steps belong to. + /// The ID of the run the run steps belong to. + /// + /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the + /// default is 20. + /// + /// + /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` + /// for descending order. Allowed values: "asc" | "desc" + /// + /// + /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include after=obj_foo in order to fetch the next page of the list. + /// + /// + /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. + /// + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// A collection of service responses, each holding a page of values. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual IAsyncEnumerable GetRunStepsAsync(string threadId, string runId, int? limit, string order, string after, string before, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + RunStepsPageEnumerator enumerator = new RunStepsPageEnumerator(_pipeline, _endpoint, threadId, runId, limit, order, after, before, options); + return PageCollectionHelpers.CreateAsync(enumerator); + } + + /// + /// [Protocol Method] Returns a paginated collection of run steps belonging to a run. + /// + /// The ID of the thread the run and run steps belong to. + /// The ID of the run the run steps belong to. + /// + /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the + /// default is 20. + /// + /// + /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` + /// for descending order. Allowed values: "asc" | "desc" + /// + /// + /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include after=obj_foo in order to fetch the next page of the list. + /// + /// + /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. + /// + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// A collection of service responses, each holding a page of values. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual IEnumerable GetRunSteps(string threadId, string runId, int? limit, string order, string after, string before, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + RunStepsPageEnumerator enumerator = new RunStepsPageEnumerator(_pipeline, _endpoint, threadId, runId, limit, order, after, before, options); + return PageCollectionHelpers.Create(enumerator); + } + + /// + /// [Protocol Method] Retrieves a run step. + /// + /// The ID of the thread to which the run and run step belongs. + /// The ID of the run to which the run step belongs. + /// The ID of the run step to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// , or is null. + /// , or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task GetRunStepAsync(string threadId, string runId, string stepId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + Argument.AssertNotNullOrEmpty(stepId, nameof(stepId)); + + using PipelineMessage message = CreateGetRunStepRequest(threadId, runId, stepId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Retrieves a run step. + /// + /// The ID of the thread to which the run and run step belongs. + /// The ID of the run to which the run step belongs. + /// The ID of the run step to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// , or is null. + /// , or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult GetRunStep(string threadId, string runId, string stepId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + Argument.AssertNotNullOrEmpty(stepId, nameof(stepId)); + + using PipelineMessage message = CreateGetRunStepRequest(threadId, runId, stepId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateCreateRunRequest(string threadId, BinaryContent content, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs", false); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + request.Headers.Set("Content-Type", "application/json"); + request.Content = content; + message.Apply(options); + return message; + } + + internal PipelineMessage CreateGetRunRequest(string threadId, string runId, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateModifyRunRequest(string threadId, string runId, BinaryContent content, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + request.Headers.Set("Content-Type", "application/json"); + request.Content = content; + message.Apply(options); + return message; + } + + internal PipelineMessage CreateCancelRunRequest(string threadId, string runId, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + uri.AppendPath("/cancel", false); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateSubmitToolOutputsToRunRequest(string threadId, string runId, BinaryContent content, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + uri.AppendPath("/submit_tool_outputs", false); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + request.Headers.Set("Content-Type", "application/json"); + request.Content = content; + message.Apply(options); + return message; + } + + internal PipelineMessage CreateGetRunStepsRequest(string threadId, string runId, int? limit, string order, string after, string before, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + uri.AppendPath("/steps", false); + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + if (order != null) + { + uri.AppendQuery("order", order, true); + } + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (before != null) + { + uri.AppendQuery("before", before, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateGetRunStepRequest(string threadId, string runId, string stepId, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + uri.AppendPath("/steps/", false); + uri.AppendPath(stepId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); + #endregion +} \ No newline at end of file diff --git a/src/To.Be.Generated/RunOperation.cs b/src/To.Be.Generated/RunOperation.cs new file mode 100644 index 00000000..e2c53430 --- /dev/null +++ b/src/To.Be.Generated/RunOperation.cs @@ -0,0 +1,392 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Assistants; + +// Convenience version +public partial class RunOperation : ClientResult +{ + // For use with polling convenience methods where the response has been + // obtained prior to creation of the LRO type. + internal RunOperation( + ClientPipeline pipeline, + Uri endpoint, + ThreadRun value, + RunStatus status, + PipelineResponse response) + : base(response) + { + _pipeline = pipeline; + _endpoint = endpoint; + _pollingInterval = new(); + + if (response.Headers.TryGetValue("Content-Type", out string? contentType) && + contentType == "text/event-stream; charset=utf-8") + { + throw new ArgumentException("Cannot create polling operation from streaming response.", nameof(response)); + } + + Value = value; + Status = status; + + ThreadId = value.ThreadId; + RunId = value.Id; + + RehydrationToken = new RunOperationToken(value.ThreadId, value.Id); + } + + // For use with streaming convenience methods - response hasn't been provided yet. + internal RunOperation( + ClientPipeline pipeline, + Uri endpoint) + : base() + { + _pipeline = pipeline; + _endpoint = endpoint; + + // This constructor is provided for streaming convenience method only. + // Because of this, we don't set the polling interval type. + } + + // Note: these all have to be nullable because the derived streaming type + // cannot set them until it reads the first event from the SSE stream. + public string? RunId { get => _runId; protected set => _runId = value; } + public string? ThreadId { get => _threadId; protected set => _threadId = value; } + + public ThreadRun? Value { get; protected set; } + public RunStatus? Status { get; protected set; } + + #region OperationResult methods + + public virtual async Task WaitUntilStoppedAsync(CancellationToken cancellationToken = default) + => await WaitUntilStoppedAsync(default, cancellationToken).ConfigureAwait(false); + + public virtual void WaitUntilStopped(CancellationToken cancellationToken = default) + => WaitUntilStopped(default, cancellationToken); + + public virtual async Task WaitUntilStoppedAsync(TimeSpan? pollingInterval, CancellationToken cancellationToken = default) + { + if (IsStreaming) + { + // We would have to read from the stream to get the run ID to poll for. + throw new NotSupportedException("Cannot poll for status updates from streaming operation."); + } + + await foreach (ThreadRun update in GetUpdatesAsync(pollingInterval, cancellationToken)) + { + // Don't keep polling if would do so infinitely. + if (update.Status == RunStatus.RequiresAction) + { + return; + } + } + } + + public virtual void WaitUntilStopped(TimeSpan? pollingInterval, CancellationToken cancellationToken = default) + { + if (IsStreaming) + { + // We would have to read from the stream to get the run ID to poll for. + throw new NotSupportedException("Cannot poll for status updates from streaming operation."); + } + + foreach (ThreadRun update in GetUpdates(pollingInterval, cancellationToken)) + { + // Don't keep polling if would do so infinitely. + if (update.Status == RunStatus.RequiresAction) + { + return; + } + } + } + + // Expose enumerable APIs similar to the streaming ones. + public virtual async IAsyncEnumerable GetUpdatesAsync( + TimeSpan? pollingInterval = default, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (pollingInterval is not null) + { + // TODO: don't allocate + _pollingInterval = new PollingInterval(pollingInterval); + } + + IAsyncEnumerator> enumerator = + new RunOperationUpdateEnumerator(_pipeline, _endpoint, _threadId!, _runId!, cancellationToken); + + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + ApplyUpdate(enumerator.Current); + + yield return enumerator.Current; + + // TODO: do we need null check? + await _pollingInterval!.WaitAsync(cancellationToken).ConfigureAwait(false); + } + } + + public virtual IEnumerable GetUpdates( + TimeSpan? pollingInterval = default, + CancellationToken cancellationToken = default) + { + if (pollingInterval is not null) + { + // TODO: don't allocate + _pollingInterval = new PollingInterval(pollingInterval); + } + + IEnumerator> enumerator = new RunOperationUpdateEnumerator( + _pipeline, _endpoint, _threadId!, _runId!, cancellationToken); + + while (enumerator.MoveNext()) + { + ApplyUpdate(enumerator.Current); + + yield return enumerator.Current; + + // TODO: do we need null check? + _pollingInterval!.Wait(); + } + } + + private void ApplyUpdate(ClientResult update) + { + Value = update; + Status = update.Value.Status; + IsCompleted = Status.Value.IsTerminal; + + SetRawResponse(update.GetRawResponse()); + } + + #endregion + + #region Convenience overloads of generated protocol methods + + /// + /// Gets an existing from a known . + /// + /// A token that can be used to cancel this method call. + /// The existing instance. + public virtual async Task> GetRunAsync(CancellationToken cancellationToken = default) + { + ClientResult protocolResult = await GetRunAsync(_threadId!, _runId!, cancellationToken.ToRequestOptions()).ConfigureAwait(false); + return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + } + + /// + /// Gets an existing from a known . + /// + /// A token that can be used to cancel this method call. + /// The existing instance. + public virtual ClientResult GetRun(CancellationToken cancellationToken = default) + { + ClientResult protocolResult = GetRun(_threadId!, _runId!, cancellationToken.ToRequestOptions()); + return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + } + + /// + /// Cancels an in-progress . + /// + /// A token that can be used to cancel this method call. + /// An updated instance, reflecting the new status of the run. + public virtual async Task> CancelRunAsync(CancellationToken cancellationToken = default) + { + ClientResult protocolResult = await CancelRunAsync(_threadId!, _runId!, cancellationToken.ToRequestOptions()).ConfigureAwait(false); + return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + } + + /// + /// Cancels an in-progress . + /// + /// A token that can be used to cancel this method call. + /// An updated instance, reflecting the new status of the run. + public virtual ClientResult CancelRun(CancellationToken cancellationToken = default) + { + ClientResult protocolResult = CancelRun(_threadId!, _runId!, cancellationToken.ToRequestOptions()); + return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + } + + /// + /// Submits a collection of required tool call outputs to a run and resumes the run. + /// + /// + /// The tool outputs, corresponding to instances from the run. + /// + /// A token that can be used to cancel this method call. + /// The , updated after the submission was processed. + public virtual async Task SubmitToolOutputsToRunAsync( + IEnumerable toolOutputs, + CancellationToken cancellationToken = default) + { + BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs).ToBinaryContent(); + ClientResult protocolResult = await SubmitToolOutputsToRunAsync(_threadId!, _runId!, content, cancellationToken.ToRequestOptions()) + .ConfigureAwait(false); + ClientResult update = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + ApplyUpdate(update); + } + + /// + /// Submits a collection of required tool call outputs to a run and resumes the run. + /// + /// + /// The tool outputs, corresponding to instances from the run. + /// + /// A token that can be used to cancel this method call. + /// The , updated after the submission was processed. + public virtual void SubmitToolOutputsToRun( + IEnumerable toolOutputs, + CancellationToken cancellationToken = default) + { + BinaryContent content = new InternalSubmitToolOutputsRunRequest(toolOutputs).ToBinaryContent(); + ClientResult protocolResult = SubmitToolOutputsToRun(_threadId!, _runId!, content, cancellationToken.ToRequestOptions()); + ClientResult update = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse); + ApplyUpdate(update); + } + + /// + /// Gets a page collection holding instances associated with a . + /// + /// + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual AsyncPageCollection GetRunStepsAsync( + RunStepCollectionOptions? options = default, + CancellationToken cancellationToken = default) + { + RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, + _threadId!, + _runId!, + options?.PageSize, + options?.Order?.ToString(), + options?.AfterId, + options?.BeforeId, + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.CreateAsync(enumerator); + } + + /// + /// Rehydrates a page collection holding instances from a page token. + /// + /// Page token corresponding to the first page of the collection to rehydrate. + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual AsyncPageCollection GetRunStepsAsync( + ContinuationToken firstPageToken, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); + + RunStepsPageToken pageToken = RunStepsPageToken.FromToken(firstPageToken); + RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, + pageToken.ThreadId, + pageToken.RunId, + pageToken.Limit, + pageToken.Order, + pageToken.After, + pageToken.Before, + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.CreateAsync(enumerator); + } + + /// + /// Gets a page collection holding instances associated with a . + /// + /// + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual PageCollection GetRunSteps( + RunStepCollectionOptions? options = default, + CancellationToken cancellationToken = default) + { + RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, + ThreadId!, + RunId!, + options?.PageSize, + options?.Order?.ToString(), + options?.AfterId, + options?.BeforeId, + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.Create(enumerator); + } + + /// + /// Rehydrates a page collection holding instances from a page token. + /// + /// Page token corresponding to the first page of the collection to rehydrate. + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual PageCollection GetRunSteps( + ContinuationToken firstPageToken, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); + + RunStepsPageToken pageToken = RunStepsPageToken.FromToken(firstPageToken); + RunStepsPageEnumerator enumerator = new(_pipeline, _endpoint, + pageToken.ThreadId, + pageToken.RunId, + pageToken.Limit, + pageToken.Order, + pageToken.After, + pageToken.Before, + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.Create(enumerator); + } + + /// + /// Gets a single run step from a run. + /// + /// The ID of the run step. + /// A token that can be used to cancel this method call. + /// A instance corresponding to the specified step. + public virtual async Task> GetRunStepAsync(string stepId, CancellationToken cancellationToken = default) + { + ClientResult protocolResult = await GetRunStepAsync(_threadId!, _runId!, stepId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); + return CreateResultFromProtocol(protocolResult, RunStep.FromResponse); + } + + /// + /// Gets a single run step from a run. + /// + /// The ID of the run step. + /// A token that can be used to cancel this method call. + /// A instance corresponding to the specified step. + public virtual ClientResult GetRunStep(string stepId, CancellationToken cancellationToken = default) + { + ClientResult protocolResult = GetRunStep(_threadId!, _runId!, stepId, cancellationToken.ToRequestOptions()); + return CreateResultFromProtocol(protocolResult, RunStep.FromResponse); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ClientResult CreateResultFromProtocol(ClientResult protocolResult, Func responseDeserializer) + { + PipelineResponse pipelineResponse = protocolResult.GetRawResponse(); + T deserializedResultValue = responseDeserializer.Invoke(pipelineResponse); + return ClientResult.FromValue(deserializedResultValue, pipelineResponse); + } + + #endregion +} \ No newline at end of file diff --git a/src/To.Be.Generated/RunOperationToken.cs b/src/To.Be.Generated/RunOperationToken.cs new file mode 100644 index 00000000..b57dec15 --- /dev/null +++ b/src/To.Be.Generated/RunOperationToken.cs @@ -0,0 +1,102 @@ +using System; +using System.ClientModel; +using System.Diagnostics; +using System.IO; +using System.Text.Json; + +#nullable enable + +namespace OpenAI.Assistants; + +internal class RunOperationToken : ContinuationToken +{ + public RunOperationToken(string threadId, string runId) + { + ThreadId = threadId; + RunId = runId; + } + + public string ThreadId { get; } + + public string RunId { get; } + + public override BinaryData ToBytes() + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + + writer.WriteStartObject(); + + writer.WriteString("threadId", ThreadId); + writer.WriteString("runId", RunId); + + writer.WriteEndObject(); + + writer.Flush(); + stream.Position = 0; + + return BinaryData.FromStream(stream); + } + + public static RunOperationToken FromToken(ContinuationToken continuationToken) + { + if (continuationToken is RunOperationToken token) + { + return token; + } + + BinaryData data = continuationToken.ToBytes(); + + if (data.ToMemory().Length == 0) + { + throw new ArgumentException("Failed to create RunOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + Utf8JsonReader reader = new(data); + + string threadId = null!; + string runId = null!; + + reader.Read(); + + Debug.Assert(reader.TokenType == JsonTokenType.StartObject); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + Debug.Assert(reader.TokenType == JsonTokenType.PropertyName); + + string propertyName = reader.GetString()!; + + switch (propertyName) + { + case "threadId": + reader.Read(); + Debug.Assert(reader.TokenType == JsonTokenType.String); + threadId = reader.GetString()!; + break; + + case "runId": + reader.Read(); + Debug.Assert(reader.TokenType == JsonTokenType.String); + threadId = reader.GetString()!; + break; + + default: + throw new JsonException($"Unrecognized property '{propertyName}'."); + } + } + + if (threadId is null || runId is null) + { + throw new ArgumentException("Failed to create RunOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + return new(threadId, runId); + } +} + diff --git a/src/To.Be.Generated/RunOperationUpdateEnumerator.Protocol.cs b/src/To.Be.Generated/RunOperationUpdateEnumerator.Protocol.cs new file mode 100644 index 00000000..3185d0b0 --- /dev/null +++ b/src/To.Be.Generated/RunOperationUpdateEnumerator.Protocol.cs @@ -0,0 +1,175 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Assistants; + +internal partial class RunOperationUpdateEnumerator : + IAsyncEnumerator, + IEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly string _threadId; + private readonly string _runId; + private readonly CancellationToken _cancellationToken; + + private ClientResult? _current; + private bool _hasNext = true; + + public RunOperationUpdateEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string threadId, + string runId, + CancellationToken cancellationToken) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _threadId = threadId; + _runId = runId; + + _cancellationToken = cancellationToken; + } + + public ClientResult Current => _current!; + + #region IEnumerator methods + + object IEnumerator.Current => _current!; + + bool IEnumerator.MoveNext() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = GetRun(_threadId, _runId, _cancellationToken.ToRequestOptions()); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + void IEnumerator.Reset() => _current = null; + + void IDisposable.Dispose() { } + + #endregion + + #region IAsyncEnumerator methods + + ClientResult IAsyncEnumerator.Current => _current!; + + public async ValueTask MoveNextAsync() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = await GetRunAsync(_threadId, _runId, _cancellationToken.ToRequestOptions()).ConfigureAwait(false); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + // TODO: handle Dispose and DisposeAsync using proper patterns? + ValueTask IAsyncDisposable.DisposeAsync() => default; + + #endregion + + // Methods used by both implementations + + private bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + // TODO: don't parse JsonDocument twice if possible + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + bool isComplete = status == "expired" || + status == "completed" || + status == "failed" || + status == "incomplete" || + status == "cancelled"; + + return !isComplete; + } + + // Generated methods + + /// + /// [Protocol Method] Retrieves a run. + /// + /// The ID of the [thread](/docs/api-reference/threads) that was run. + /// The ID of the run to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual async Task GetRunAsync(string threadId, string runId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + using PipelineMessage message = CreateGetRunRequest(threadId, runId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Retrieves a run. + /// + /// The ID of the [thread](/docs/api-reference/threads) that was run. + /// The ID of the run to retrieve. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + public virtual ClientResult GetRun(string threadId, string runId, RequestOptions? options) + { + Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + Argument.AssertNotNullOrEmpty(runId, nameof(runId)); + + using PipelineMessage message = CreateGetRunRequest(threadId, runId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateGetRunRequest(string threadId, string runId, RequestOptions? options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/threads/", false); + uri.AppendPath(threadId, true); + uri.AppendPath("/runs/", false); + uri.AppendPath(runId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/To.Be.Generated/RunOperationUpdateEnumerator.cs b/src/To.Be.Generated/RunOperationUpdateEnumerator.cs new file mode 100644 index 00000000..8d485771 --- /dev/null +++ b/src/To.Be.Generated/RunOperationUpdateEnumerator.cs @@ -0,0 +1,54 @@ +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; + +#nullable enable + +namespace OpenAI.Assistants; + +internal partial class RunOperationUpdateEnumerator : + IAsyncEnumerator>, + IEnumerator> +{ + #region IEnumerator> methods + + ClientResult IEnumerator>.Current + { + get + { + if (Current is null) + { + return default!; + } + + return GetUpdateFromResult(Current); + } + } + + #endregion + + #region IAsyncEnumerator> methods + + ClientResult IAsyncEnumerator>.Current + { + get + { + if (Current is null) + { + return default!; + } + + return GetUpdateFromResult(Current); + } + } + + #endregion + + // Methods used by convenience implementation + private ClientResult GetUpdateFromResult(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + ThreadRun run = ThreadRun.FromResponse(response); + return ClientResult.FromValue(run, response); + } +} diff --git a/src/To.Be.Generated/StreamingRunOperation.cs b/src/To.Be.Generated/StreamingRunOperation.cs new file mode 100644 index 00000000..d47542fe --- /dev/null +++ b/src/To.Be.Generated/StreamingRunOperation.cs @@ -0,0 +1,280 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Assistants; + +// Streaming version +public partial class StreamingRunOperation : RunOperation +{ + private readonly Func> _createRunAsync; + private readonly Func _createRun; + + + private StreamingRunOperationUpdateEnumerator? _enumerator; + + internal StreamingRunOperation( + ClientPipeline pipeline, + Uri endpoint, + + // Note if we pass funcs we don't need to pass in the pipeline. + Func> createRunAsync, + Func createRun) + : base(pipeline, endpoint) + { + _createRunAsync = createRunAsync; + _createRun = createRun; + } + + // TODO: this duplicates a field on the base type. Address? + public override bool IsCompleted { get; protected set; } + + public override async Task WaitUntilStoppedAsync(CancellationToken cancellationToken = default) + { + // TODO: add validation that stream is only requested and enumerated once. + // TODO: Make sure you can't create the same run twice and/or submit tools twice + // somehow, even accidentally. + + await foreach (StreamingUpdate update in GetUpdatesStreamingAsync(cancellationToken).ConfigureAwait(false)) + { + // Should terminate naturally when get to "requires action" because + // the SSE stream will end. + } + } + + public override void WaitUntilStopped(CancellationToken cancellationToken = default) + { + foreach (StreamingUpdate update in GetUpdatesStreaming(cancellationToken)) + { + // Should terminate naturally when get to "requires action" because + // the SSE stream will end. + } + } + + // Public APIs specific to streaming LRO + public virtual async IAsyncEnumerable GetUpdatesStreamingAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + try + { + if (_enumerator is null) + { + AsyncStreamingUpdateCollection updates = new AsyncStreamingUpdateCollection(_createRunAsync); + _enumerator = new StreamingRunOperationUpdateEnumerator(updates); + } + + while (await _enumerator.MoveNextAsync().ConfigureAwait(false)) + { + if (_enumerator.Current is RunUpdate update) + { + ApplyUpdate(update); + } + + cancellationToken.ThrowIfCancellationRequested(); + + yield return _enumerator.Current; + } + } + finally + { + if (_enumerator != null) + { + await _enumerator.DisposeAsync(); + _enumerator = null; + } + } + } + + public virtual IEnumerable GetUpdatesStreaming(CancellationToken cancellationToken = default) + { + try + { + if (_enumerator is null) + { + StreamingUpdateCollection updates = new StreamingUpdateCollection(_createRun); + _enumerator = new StreamingRunOperationUpdateEnumerator(updates); + } + + while (_enumerator.MoveNext()) + { + if (_enumerator.Current is RunUpdate update) + { + ApplyUpdate(update); + } + + cancellationToken.ThrowIfCancellationRequested(); + + yield return _enumerator.Current; + } + } + finally + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + } + } + + public override async IAsyncEnumerable GetUpdatesAsync(TimeSpan? pollingInterval = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (pollingInterval is not null) + { + throw new NotSupportedException("Cannot specify polling interval for streaming operation."); + } + + await foreach (StreamingUpdate update in GetUpdatesStreamingAsync(cancellationToken).ConfigureAwait(false)) + { + if (update is RunUpdate runUpdate) + { + yield return runUpdate; + } + } + } + + public override IEnumerable GetUpdates(TimeSpan? pollingInterval = null, CancellationToken cancellationToken = default) + { + if (pollingInterval is not null) + { + throw new NotSupportedException("Cannot specify polling interval for streaming operation."); + } + + foreach (StreamingUpdate update in GetUpdatesStreaming()) + { + if (update is RunUpdate runUpdate) + { + yield return runUpdate; + } + } + } + + private void ApplyUpdate(ThreadRun update) + { + RunId ??= update.Id; + ThreadId ??= update.ThreadId; + + Value = update; + Status = update.Status; + IsCompleted = update.Status.IsTerminal; + + SetRawResponse(_enumerator!.GetRawResponse()); + } + + public virtual async Task SubmitToolOutputsToRunStreamingAsync( + IEnumerable toolOutputs, + CancellationToken cancellationToken = default) + { + if (ThreadId is null || RunId is null) + { + throw new InvalidOperationException("Cannot submit tools until first update stream has been applied."); + } + + BinaryContent content = new InternalSubmitToolOutputsRunRequest( + toolOutputs.ToList(), stream: true, null).ToBinaryContent(); + + // TODO: can we do this the same way as this in the other method instead + // of having to take all those funcs? + async Task getResultAsync() => + await SubmitToolOutputsToRunAsync(ThreadId, RunId, content, cancellationToken.ToRequestOptions(streaming: true)) + .ConfigureAwait(false); + + AsyncStreamingUpdateCollection updates = new AsyncStreamingUpdateCollection(getResultAsync); + if (_enumerator is null) + { + _enumerator = new StreamingRunOperationUpdateEnumerator(updates); + } + else + { + await _enumerator.ReplaceUpdateCollectionAsync(updates).ConfigureAwait(false); + } + } + + public virtual void SubmitToolOutputsToRunStreaming( + IEnumerable toolOutputs, + CancellationToken cancellationToken = default) + { + if (ThreadId is null || RunId is null) + { + throw new InvalidOperationException("Cannot submit tools until first update stream has been applied."); + } + + if (_enumerator is null) + { + throw new InvalidOperationException( + "Cannot submit tools until first run update stream has been enumerated. " + + "Call 'Wait' or 'GetUpdatesStreaming' to read update stream."); + } + + BinaryContent content = new InternalSubmitToolOutputsRunRequest( + toolOutputs.ToList(), stream: true, null).ToBinaryContent(); + + // TODO: can we do this the same way as this in the other method instead + // of having to take all those funcs? + ClientResult getResult() => + SubmitToolOutputsToRun(ThreadId, RunId, content, cancellationToken.ToRequestOptions(streaming: true)); + + StreamingUpdateCollection updates = new StreamingUpdateCollection(getResult); + if (_enumerator is null) + { + _enumerator = new StreamingRunOperationUpdateEnumerator(updates); + } + else + { + _enumerator.ReplaceUpdateCollection(updates); + } + } + + #region hide + + //// used to defer first request. + //internal virtual async Task CreateRunAsync(string threadId, BinaryContent content, RequestOptions? options = null) + //{ + // Argument.AssertNotNullOrEmpty(threadId, nameof(threadId)); + // Argument.AssertNotNull(content, nameof(content)); + + // PipelineMessage? message = null; + // try + // { + // message = CreateCreateRunRequest(threadId, content, options); + // return ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + // } + // finally + // { + // if (options?.BufferResponse != false) + // { + // message?.Dispose(); + // } + // } + //} + + //internal PipelineMessage CreateCreateRunRequest(string threadId, BinaryContent content, RequestOptions? options) + //{ + // var message = Pipeline.CreateMessage(); + // message.ResponseClassifier = PipelineMessageClassifier200; + // var request = message.Request; + // request.Method = "POST"; + // var uri = new ClientUriBuilder(); + // uri.Reset(_endpoint); + // uri.AppendPath("/threads/", false); + // uri.AppendPath(threadId, true); + // uri.AppendPath("/runs", false); + // request.Uri = uri.ToUri(); + // request.Headers.Set("Accept", "application/json"); + // request.Headers.Set("Content-Type", "application/json"); + // request.Content = content; + // message.Apply(options); + // return message; + //} + + //private static PipelineMessageClassifier? _pipelineMessageClassifier200; + //private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); + #endregion +} diff --git a/src/To.Be.Generated/StreamingRunOperationUpdateEnumerator.cs b/src/To.Be.Generated/StreamingRunOperationUpdateEnumerator.cs new file mode 100644 index 00000000..93a933cb --- /dev/null +++ b/src/To.Be.Generated/StreamingRunOperationUpdateEnumerator.cs @@ -0,0 +1,145 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Assistants; + +internal partial class StreamingRunOperationUpdateEnumerator : + IAsyncEnumerator, + IEnumerator +{ + private StreamingUpdate? _current; + + private AsyncStreamingUpdateCollection? _asyncUpdates; + private IAsyncEnumerator? _asyncEnumerator; + + private StreamingUpdateCollection? _updates; + private IEnumerator? _enumerator; + + public StreamingRunOperationUpdateEnumerator( + AsyncStreamingUpdateCollection updates) + { + _asyncUpdates = updates; + } + + public StreamingRunOperationUpdateEnumerator( + StreamingUpdateCollection updates) + { + _updates = updates; + } + + // Cache this here for now + public PipelineResponse GetRawResponse() => + _asyncUpdates?.GetRawResponse() ?? + _updates?.GetRawResponse() ?? + throw new InvalidOperationException("No response available."); + + public StreamingUpdate Current => _current!; + + #region IEnumerator methods + + object IEnumerator.Current => _current!; + + public bool MoveNext() + { + if (_updates is null) + { + throw new InvalidOperationException("Cannot MoveNext after starting enumerator asynchronously."); + } + + _enumerator ??= _updates.GetEnumerator(); + + bool movedNext = _enumerator.MoveNext(); + _current = _enumerator.Current; + return movedNext; + } + + void IEnumerator.Reset() + { + throw new NotSupportedException("Cannot reset streaming enumerator."); + } + + public void Dispose() + { + if (_enumerator != null) + { + _enumerator.Dispose(); + _enumerator = null; + } + } + + #endregion + + #region IAsyncEnumerator methods + + public async ValueTask MoveNextAsync() + { + if (_asyncUpdates is null) + { + throw new InvalidOperationException("Cannot MoveNextAsync after starting enumerator synchronously."); + } + + _asyncEnumerator ??= _asyncUpdates.GetAsyncEnumerator(); + + bool movedNext = await _asyncEnumerator.MoveNextAsync().ConfigureAwait(false); + _current = _asyncEnumerator.Current; + return movedNext; + } + + public async ValueTask DisposeAsync() + { + // TODO: implement according to pattern + + if (_asyncEnumerator is null) + { + return; + } + + await _asyncEnumerator.DisposeAsync().ConfigureAwait(false); + } + + #endregion + + public async Task ReplaceUpdateCollectionAsync(AsyncStreamingUpdateCollection updates) + { + if (_asyncUpdates is null) + { + throw new InvalidOperationException("Cannot replace null update collection."); + } + + if (_updates is not null || _enumerator is not null) + { + throw new InvalidOperationException("Cannot being enumerating asynchronously after enumerating synchronously."); + } + + if (_asyncEnumerator is not null) + { + await _asyncEnumerator.DisposeAsync().ConfigureAwait(false); + _asyncEnumerator = null; + } + + _asyncUpdates = updates; + } + + public void ReplaceUpdateCollection(StreamingUpdateCollection updates) + { + if (_updates is null) + { + throw new InvalidOperationException("Cannot replace null update collection."); + } + + if (_asyncUpdates is not null || _asyncEnumerator is not null) + { + throw new InvalidOperationException("Cannot being enumerating synchronously after enumerating asynchronously."); + } + + _enumerator?.Dispose(); + _enumerator = null; + + _updates = updates; + } +} diff --git a/src/To.Be.Generated/VectorStoreFileBatchOperation.Protocol.cs b/src/To.Be.Generated/VectorStoreFileBatchOperation.Protocol.cs new file mode 100644 index 00000000..c95cec02 --- /dev/null +++ b/src/To.Be.Generated/VectorStoreFileBatchOperation.Protocol.cs @@ -0,0 +1,343 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.ComponentModel; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.VectorStores; + +// Protocol version +public partial class VectorStoreFileBatchOperation : OperationResult +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly string _vectorStoreId; + private readonly string _batchId; + + private PollingInterval _pollingInterval; + + // For use with protocol methods where the response has been obtained prior + // to creation of the LRO instance. + internal VectorStoreFileBatchOperation( + ClientPipeline pipeline, + Uri endpoint, + string vectorStoreId, + string batchId, + string status, + PipelineResponse response) + : base(response) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _vectorStoreId = vectorStoreId; + _batchId = batchId; + + IsCompleted = GetIsCompleted(status); + + _pollingInterval = new(); + + RehydrationToken = new VectorStoreFileBatchOperationToken(vectorStoreId, batchId); + } + + public override ContinuationToken? RehydrationToken { get; protected set; } + + public override bool IsCompleted { get; protected set; } + + // These are replaced when LRO is evolved to have conveniences + //public override async Task WaitAsync(CancellationToken cancellationToken = default) + //{ + // IAsyncEnumerator enumerator = + // new VectorStoreFileBatchOperationUpdateEnumerator( + // _pipeline, _endpoint, _vectorStoreId, _batchId, _options); + + // while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + // { + // ApplyUpdate(enumerator.Current); + + // cancellationToken.ThrowIfCancellationRequested(); + + // // TODO: Plumb through cancellation token + // await _pollingInterval.WaitAsync(); + // } + //} + + //public override void Wait(CancellationToken cancellationToken = default) + //{ + // IEnumerator enumerator = + // new VectorStoreFileBatchOperationUpdateEnumerator( + // _pipeline, _endpoint, _vectorStoreId, _batchId, _options); + + // while (enumerator.MoveNext()) + // { + // ApplyUpdate(enumerator.Current); + + // cancellationToken.ThrowIfCancellationRequested(); + + // _pollingInterval.Wait(); + // } + //} + + private void ApplyUpdate(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + IsCompleted = GetIsCompleted(status); + SetRawResponse(response); + } + + private static bool GetIsCompleted(string? status) + { + return status == "completed" || + status == "cancelled" || + status == "failed"; + } + + // Generated protocol methods + + /// + /// [Protocol Method] Retrieves a vector store file batch. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch being retrieved. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task GetBatchFileJobAsync(string vectorStoreId, string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateGetVectorStoreFileBatchRequest(vectorStoreId, batchId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Retrieves a vector store file batch. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch being retrieved. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult GetBatchFileJob(string vectorStoreId, string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateGetVectorStoreFileBatchRequest(vectorStoreId, batchId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + /// + /// [Protocol Method] Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task CancelBatchFileJobAsync(string vectorStoreId, string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateCancelVectorStoreFileBatchRequest(vectorStoreId, batchId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch to cancel. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult CancelBatchFileJob(string vectorStoreId, string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateCancelVectorStoreFileBatchRequest(vectorStoreId, batchId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + /// + /// [Protocol Method] Returns a paginated collection of vector store files in a batch. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch that the files belong to. + /// + /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the + /// default is 20. + /// + /// + /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` + /// for descending order. Allowed values: "asc" | "desc" + /// + /// + /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include after=obj_foo in order to fetch the next page of the list. + /// + /// + /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. + /// + /// Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// A collection of service responses, each holding a page of values. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual IAsyncEnumerable GetFileAssociationsAsync(string vectorStoreId, string batchId, int? limit, string order, string after, string before, string filter, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + VectorStoreFileBatchesPageEnumerator enumerator = new VectorStoreFileBatchesPageEnumerator(_pipeline, _endpoint, vectorStoreId, batchId, limit, order, after, before, filter, options); + return PageCollectionHelpers.CreateAsync(enumerator); + } + + /// + /// [Protocol Method] Returns a paginated collection of vector store files in a batch. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch that the files belong to. + /// + /// A limit on the number of objects to be returned. Limit can range between 1 and 100, and the + /// default is 20. + /// + /// + /// Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and`desc` + /// for descending order. Allowed values: "asc" | "desc" + /// + /// + /// A cursor for use in pagination. `after` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include after=obj_foo in order to fetch the next page of the list. + /// + /// + /// A cursor for use in pagination. `before` is an object ID that defines your place in the list. + /// For instance, if you make a list request and receive 100 objects, ending with obj_foo, your + /// subsequent call can include before=obj_foo in order to fetch the previous page of the list. + /// + /// Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// A collection of service responses, each holding a page of values. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual IEnumerable GetFileAssociations(string vectorStoreId, string batchId, int? limit, string order, string after, string before, string filter, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + VectorStoreFileBatchesPageEnumerator enumerator = new VectorStoreFileBatchesPageEnumerator(_pipeline, _endpoint, vectorStoreId, batchId, limit, order, after, before, filter, options); + return PageCollectionHelpers.Create(enumerator); + } + + internal PipelineMessage CreateGetVectorStoreFileBatchRequest(string vectorStoreId, string batchId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/vector_stores/", false); + uri.AppendPath(vectorStoreId, true); + uri.AppendPath("/file_batches/", false); + uri.AppendPath(batchId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateCancelVectorStoreFileBatchRequest(string vectorStoreId, string batchId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "POST"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/vector_stores/", false); + uri.AppendPath(vectorStoreId, true); + uri.AppendPath("/file_batches/", false); + uri.AppendPath(batchId, true); + uri.AppendPath("/cancel", false); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + internal PipelineMessage CreateGetFilesInVectorStoreBatchesRequest(string vectorStoreId, string batchId, int? limit, string order, string after, string before, string filter, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/vector_stores/", false); + uri.AppendPath(vectorStoreId, true); + uri.AppendPath("/file_batches/", false); + uri.AppendPath(batchId, true); + uri.AppendPath("/files", false); + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + if (order != null) + { + uri.AppendQuery("order", order, true); + } + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (before != null) + { + uri.AppendQuery("before", before, true); + } + if (filter != null) + { + uri.AppendQuery("filter", filter, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} \ No newline at end of file diff --git a/src/To.Be.Generated/VectorStoreFileBatchOperation.cs b/src/To.Be.Generated/VectorStoreFileBatchOperation.cs new file mode 100644 index 00000000..e6a3e6bf --- /dev/null +++ b/src/To.Be.Generated/VectorStoreFileBatchOperation.cs @@ -0,0 +1,291 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.VectorStores; + +// Convenience version +public partial class VectorStoreFileBatchOperation : OperationResult +{ + // Convenience version + internal VectorStoreFileBatchOperation( + ClientPipeline pipeline, + Uri endpoint, + ClientResult result) + : base(result.GetRawResponse()) + { + _pipeline = pipeline; + _endpoint = endpoint; + + Value = result; + Status = Value.Status; + IsCompleted = GetIsCompleted(Value.Status); + + _vectorStoreId = Value.VectorStoreId; + _batchId = Value.BatchId; + + _pollingInterval = new(); + + RehydrationToken = new VectorStoreFileBatchOperationToken(VectorStoreId, BatchId); + } + + // TODO: interesting question regarding whether these properties should be + // nullable or not. If someone has called the protocol method, do they want + // to pay the perf cost of deserialization? This could capitalize on a + // property on RequestOptions that allows the caller to opt-in to creation + // of convenience models. For now, make them nullable so I don't have to + // pass the model into the constructor from a protocol method. + public VectorStoreBatchFileJob? Value { get; private set; } + public VectorStoreBatchFileJobStatus? Status { get; private set; } + + public string VectorStoreId { get => _vectorStoreId; } + public string BatchId { get => _batchId; } + + public override async Task WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + IAsyncEnumerator> enumerator = + new VectorStoreFileBatchOperationUpdateEnumerator( + _pipeline, _endpoint, _vectorStoreId, _batchId, cancellationToken); + + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + ApplyUpdate(enumerator.Current); + + await _pollingInterval.WaitAsync(cancellationToken); + } + } + + public override void WaitForCompletion(CancellationToken cancellationToken = default) + { + IEnumerator> enumerator = + new VectorStoreFileBatchOperationUpdateEnumerator( + _pipeline, _endpoint, _vectorStoreId, _batchId, cancellationToken); + + while (enumerator.MoveNext()) + { + ApplyUpdate(enumerator.Current); + + cancellationToken.ThrowIfCancellationRequested(); + + _pollingInterval.Wait(); + } + } + + private void ApplyUpdate(ClientResult update) + { + Value = update; + Status = Value.Status; + + IsCompleted = GetIsCompleted(Value.Status); + SetRawResponse(update.GetRawResponse()); + } + + private static bool GetIsCompleted(VectorStoreBatchFileJobStatus status) + { + return status == VectorStoreBatchFileJobStatus.Completed || + status == VectorStoreBatchFileJobStatus.Cancelled || + status == VectorStoreBatchFileJobStatus.Failed; + } + + // Generated convenience methods + + /// + /// Gets an existing vector store batch file ingestion job from a known vector store ID and job ID. + /// + /// A token that can be used to cancel this method call. + /// A instance representing the ingestion operation. + public virtual async Task> GetBatchFileJobAsync(CancellationToken cancellationToken = default) + { + ClientResult result = await GetBatchFileJobAsync(_vectorStoreId, _batchId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); + PipelineResponse response = result.GetRawResponse(); + VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); + return ClientResult.FromValue(value, response); + } + + /// + /// Gets an existing vector store batch file ingestion job from a known vector store ID and job ID. + /// + /// A token that can be used to cancel this method call. + /// A instance representing the ingestion operation. + public virtual ClientResult GetBatchFileJob(CancellationToken cancellationToken = default) + { + ClientResult result = GetBatchFileJob(_vectorStoreId, _batchId, cancellationToken.ToRequestOptions()); + PipelineResponse response = result.GetRawResponse(); + VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); + return ClientResult.FromValue(value, response); + } + + /// + /// Cancels an in-progress . + /// + /// A token that can be used to cancel this method call. + /// An updated instance. + public virtual async Task> CancelBatchFileJobAsync(CancellationToken cancellationToken = default) + { + ClientResult result = await CancelBatchFileJobAsync(_vectorStoreId, _batchId, cancellationToken.ToRequestOptions()).ConfigureAwait(false); + PipelineResponse response = result.GetRawResponse(); + VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); + return ClientResult.FromValue(value, response); + } + + /// + /// Cancels an in-progress . + /// + /// A token that can be used to cancel this method call. + /// An updated instance. + public virtual ClientResult CancelBatchFileJob(CancellationToken cancellationToken = default) + { + ClientResult result = CancelBatchFileJob(_vectorStoreId, _batchId, cancellationToken.ToRequestOptions()); + PipelineResponse response = result.GetRawResponse(); + VectorStoreBatchFileJob value = VectorStoreBatchFileJob.FromResponse(response); + return ClientResult.FromValue(value, response); + } + + /// + /// Gets a page collection of file associations associated with a vector store batch file job, representing the files + /// that were scheduled for ingestion into the vector store. + /// + /// Options describing the collection to return. + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual AsyncPageCollection GetFileAssociationsAsync( + VectorStoreFileAssociationCollectionOptions? options = default, + CancellationToken cancellationToken = default) + { + VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, + _vectorStoreId, + _batchId, + options?.PageSize, + options?.Order?.ToString(), + options?.AfterId, + options?.BeforeId, + options?.Filter?.ToString(), + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.CreateAsync(enumerator); + } + + /// + /// Rehydrates a page collection of file associations from a page token. + /// + /// Page token corresponding to the first page of the collection to rehydrate. + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual AsyncPageCollection GetFileAssociationsAsync( + ContinuationToken firstPageToken, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); + + VectorStoreFileBatchesPageToken pageToken = VectorStoreFileBatchesPageToken.FromToken(firstPageToken); + + if (_vectorStoreId != pageToken.VectorStoreId) + { + throw new ArgumentException( + "Invalid page token. 'VectorStoreId' value does not match page token value.", + nameof(firstPageToken)); + } + + if (_batchId != pageToken.BatchId) + { + throw new ArgumentException( + "Invalid page token. 'BatchId' value does not match page token value.", + nameof(firstPageToken)); + } + + VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, + pageToken.VectorStoreId, + pageToken.BatchId, + pageToken.Limit, + pageToken.Order, + pageToken.After, + pageToken.Before, + pageToken.Filter, + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.CreateAsync(enumerator); + } + + /// + /// Gets a page collection of file associations associated with a vector store batch file job, representing the files + /// that were scheduled for ingestion into the vector store. + /// + /// Options describing the collection to return. + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual PageCollection GetFileAssociations( + VectorStoreFileAssociationCollectionOptions? options = default, + CancellationToken cancellationToken = default) + { + VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, + _vectorStoreId, + _batchId, + options?.PageSize, + options?.Order?.ToString(), + options?.AfterId, + options?.BeforeId, + options?.Filter?.ToString(), + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.Create(enumerator); + } + + /// + /// Rehydrates a page collection of file associations from a page token. + /// that were scheduled for ingestion into the vector store. + /// + /// Page token corresponding to the first page of the collection to rehydrate. + /// A token that can be used to cancel this method call. + /// holds pages of values. To obtain a collection of values, call + /// . To obtain the current + /// page of values, call . + /// A collection of pages of . + public virtual PageCollection GetFileAssociations( + ContinuationToken firstPageToken, + CancellationToken cancellationToken = default) + { + Argument.AssertNotNull(firstPageToken, nameof(firstPageToken)); + + VectorStoreFileBatchesPageToken pageToken = VectorStoreFileBatchesPageToken.FromToken(firstPageToken); + + if (_vectorStoreId != pageToken.VectorStoreId) + { + throw new ArgumentException( + "Invalid page token. 'VectorStoreId' value does not match page token value.", + nameof(firstPageToken)); + } + + if (_batchId != pageToken.BatchId) + { + throw new ArgumentException( + "Invalid page token. 'BatchId' value does not match page token value.", + nameof(firstPageToken)); + } + + VectorStoreFileBatchesPageEnumerator enumerator = new(_pipeline, _endpoint, + pageToken.VectorStoreId, + pageToken.BatchId, + pageToken.Limit, + pageToken.Order, + pageToken.After, + pageToken.Before, + pageToken.Filter, + cancellationToken.ToRequestOptions()); + + return PageCollectionHelpers.Create(enumerator); + } +} \ No newline at end of file diff --git a/src/To.Be.Generated/VectorStoreFileBatchOperationToken.cs b/src/To.Be.Generated/VectorStoreFileBatchOperationToken.cs new file mode 100644 index 00000000..bcda6b72 --- /dev/null +++ b/src/To.Be.Generated/VectorStoreFileBatchOperationToken.cs @@ -0,0 +1,101 @@ +using System; +using System.ClientModel; +using System.Diagnostics; +using System.IO; +using System.Text.Json; + +#nullable enable + +namespace OpenAI.VectorStores; + +internal class VectorStoreFileBatchOperationToken : ContinuationToken +{ + public VectorStoreFileBatchOperationToken(string vectorStoreId, string batchId) + { + VectorStoreId = vectorStoreId; + BatchId = batchId; + } + + public string VectorStoreId { get; } + + public string BatchId { get; } + + public override BinaryData ToBytes() + { + using MemoryStream stream = new(); + using Utf8JsonWriter writer = new(stream); + writer.WriteStartObject(); + + writer.WriteString("vectorStoreId", VectorStoreId); + writer.WriteString("batchId", BatchId); + + writer.WriteEndObject(); + + writer.Flush(); + stream.Position = 0; + + return BinaryData.FromStream(stream); + } + + public static VectorStoreFileBatchOperationToken FromToken(ContinuationToken continuationToken) + { + if (continuationToken is VectorStoreFileBatchOperationToken token) + { + return token; + } + + BinaryData data = continuationToken.ToBytes(); + + if (data.ToMemory().Length == 0) + { + throw new ArgumentException("Failed to create VectorStoreFileBatchOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + Utf8JsonReader reader = new(data); + + string vectorStoreId = null!; + string batchId = null!; + + reader.Read(); + + Debug.Assert(reader.TokenType == JsonTokenType.StartObject); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + Debug.Assert(reader.TokenType == JsonTokenType.PropertyName); + + string propertyName = reader.GetString()!; + + switch (propertyName) + { + case "vectorStoreId": + reader.Read(); + Debug.Assert(reader.TokenType == JsonTokenType.String); + vectorStoreId = reader.GetString()!; + break; + + case "batchId": + reader.Read(); + Debug.Assert(reader.TokenType == JsonTokenType.String); + batchId = reader.GetString()!; + break; + + default: + throw new JsonException($"Unrecognized property '{propertyName}'."); + } + } + + if (vectorStoreId is null || batchId is null) + { + throw new ArgumentException("Failed to create VectorStoreFileBatchOperationToken from provided continuationToken.", nameof(continuationToken)); + } + + return new(vectorStoreId, batchId); + } +} + diff --git a/src/To.Be.Generated/VectorStoreFileBatchOperationUpdateEnumerator.Protocol.cs b/src/To.Be.Generated/VectorStoreFileBatchOperationUpdateEnumerator.Protocol.cs new file mode 100644 index 00000000..bf3b1f1a --- /dev/null +++ b/src/To.Be.Generated/VectorStoreFileBatchOperationUpdateEnumerator.Protocol.cs @@ -0,0 +1,175 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.VectorStores; + +internal partial class VectorStoreFileBatchOperationUpdateEnumerator : + IAsyncEnumerator, + IEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + private readonly CancellationToken _cancellationToken; + + private readonly string _vectorStoreId; + private readonly string _batchId; + + private ClientResult? _current; + private bool _hasNext = true; + + public VectorStoreFileBatchOperationUpdateEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string vectorStoreId, + string batchId, + CancellationToken cancellationToken) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _vectorStoreId = vectorStoreId; + _batchId = batchId; + + _cancellationToken = cancellationToken; + } + + public ClientResult Current => _current!; + + #region IEnumerator methods + + object IEnumerator.Current => _current!; + + bool IEnumerator.MoveNext() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = GetBatchFileJob(_vectorStoreId, _batchId, _cancellationToken.ToRequestOptions()); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + void IEnumerator.Reset() => _current = null; + + void IDisposable.Dispose() { } + + #endregion + + #region IAsyncEnumerator methods + + ClientResult IAsyncEnumerator.Current => _current!; + + public async ValueTask MoveNextAsync() + { + if (!_hasNext) + { + _current = null; + return false; + } + + ClientResult result = await GetBatchFileJobAsync(_vectorStoreId, _batchId, _cancellationToken.ToRequestOptions()).ConfigureAwait(false); + + _current = result; + _hasNext = HasNext(result); + + return true; + } + + // TODO: handle Dispose and DisposeAsync using proper patterns? + ValueTask IAsyncDisposable.DisposeAsync() => default; + + #endregion + + private bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + // TODO: don't parse JsonDocument twice if possible + using JsonDocument doc = JsonDocument.Parse(response.Content); + string? status = doc.RootElement.GetProperty("status"u8).GetString(); + + bool isComplete = status == "completed" || + status == "cancelled" || + status == "failed"; + + return !isComplete; + } + + // Generated methods + + /// + /// [Protocol Method] Retrieves a vector store file batch. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch being retrieved. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual async Task GetBatchFileJobAsync(string vectorStoreId, string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateGetVectorStoreFileBatchRequest(vectorStoreId, batchId, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + /// + /// [Protocol Method] Retrieves a vector store file batch. + /// + /// The ID of the vector store that the file batch belongs to. + /// The ID of the file batch being retrieved. + /// The request options, which can override default behaviors of the client pipeline on a per-call basis. + /// or is null. + /// or is an empty string, and was expected to be non-empty. + /// Service returned a non-success status code. + /// The response returned from the service. + [EditorBrowsable(EditorBrowsableState.Never)] + public virtual ClientResult GetBatchFileJob(string vectorStoreId, string batchId, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(vectorStoreId, nameof(vectorStoreId)); + Argument.AssertNotNullOrEmpty(batchId, nameof(batchId)); + + using PipelineMessage message = CreateGetVectorStoreFileBatchRequest(vectorStoreId, batchId, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateGetVectorStoreFileBatchRequest(string vectorStoreId, string batchId, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/vector_stores/", false); + uri.AppendPath(vectorStoreId, true); + uri.AppendPath("/file_batches/", false); + uri.AppendPath(batchId, true); + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/To.Be.Generated/VectorStoreFileBatchOperationUpdateEnumerator.cs b/src/To.Be.Generated/VectorStoreFileBatchOperationUpdateEnumerator.cs new file mode 100644 index 00000000..c901c63d --- /dev/null +++ b/src/To.Be.Generated/VectorStoreFileBatchOperationUpdateEnumerator.cs @@ -0,0 +1,54 @@ +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; + +#nullable enable + +namespace OpenAI.VectorStores; + +internal partial class VectorStoreFileBatchOperationUpdateEnumerator : + IAsyncEnumerator>, + IEnumerator> +{ + #region IEnumerator> methods + + ClientResult IEnumerator>.Current + { + get + { + if (Current is null) + { + return default!; + } + + return GetUpdateFromResult(Current); + } + } + + #endregion + + #region IAsyncEnumerator> methods + + ClientResult IAsyncEnumerator>.Current + { + get + { + if (Current is null) + { + return default!; + } + + return GetUpdateFromResult(Current); + } + } + + #endregion + + // Methods used by convenience implementation + private ClientResult GetUpdateFromResult(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + VectorStoreBatchFileJob run = VectorStoreBatchFileJob.FromResponse(response); + return ClientResult.FromValue(run, response); + } +} diff --git a/tests/Assistants/AssistantTests.cs b/tests/Assistants/AssistantTests.cs index 4aab667e..b2aa2b4c 100644 --- a/tests/Assistants/AssistantTests.cs +++ b/tests/Assistants/AssistantTests.cs @@ -8,7 +8,8 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -using System.Threading; +using System.Net.ServerSentEvents; +using System.Text.Json; using System.Threading.Tasks; using static OpenAI.Tests.TestHelpers; @@ -230,23 +231,25 @@ public void BasicRunOperationsWork() Assert.That(runsPage.Values.Count, Is.EqualTo(0)); ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); Validate(message); - ThreadRun run = client.CreateRun(thread.Id, assistant.Id); - Validate(run); - Assert.That(run.Status, Is.EqualTo(RunStatus.Queued)); - Assert.That(run.CreatedAt, Is.GreaterThan(s_2024)); - ThreadRun retrievedRun = client.GetRun(thread.Id, run.Id); - Assert.That(retrievedRun.Id, Is.EqualTo(run.Id)); + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread.Id, assistant.Id); + Validate(runOperation); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Queued)); + Assert.That(runOperation.Value.CreatedAt, Is.GreaterThan(s_2024)); + Assert.That(runOperation.IsCompleted, Is.False); + //ThreadRun retrievedRun = client.GetRun(thread.Id, run.Id); + //Assert.That(retrievedRun.Id, Is.EqualTo(run.Id)); runsPage = client.GetRuns(thread).GetCurrentPage(); Assert.That(runsPage.Values.Count, Is.EqualTo(1)); - Assert.That(runsPage.Values[0].Id, Is.EqualTo(run.Id)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runOperation.RunId)); PageResult messagesPage = client.GetMessages(thread).GetCurrentPage(); Assert.That(messagesPage.Values.Count, Is.GreaterThanOrEqualTo(1)); - for (int i = 0; i < 10 && !run.Status.IsTerminal; i++) - { - Thread.Sleep(500); - run = client.GetRun(run); - } + + runOperation.WaitUntilStopped(); + Assert.That(runOperation.IsCompleted, Is.True); + + ThreadRun run = runOperation.Value; + Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); Assert.That(run.CompletedAt, Is.GreaterThan(s_2024)); Assert.That(run.RequiredActions.Count, Is.EqualTo(0)); @@ -303,18 +306,17 @@ public void BasicRunStepFunctionalityWorks() }); Validate(thread); - ThreadRun run = client.CreateRun(thread, assistant); - Validate(run); + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread, assistant); + Validate(runOperation); - while (!run.Status.IsTerminal) - { - Thread.Sleep(1000); - run = client.GetRun(run); - } + runOperation.WaitUntilStopped(); + Assert.That(runOperation.IsCompleted, Is.True); + + ThreadRun run = runOperation.Value; Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); Assert.That(run.Usage?.TotalTokens, Is.GreaterThan(0)); - PageCollection pages = client.GetRunSteps(run); + PageCollection pages = runOperation.GetRunSteps(); PageResult firstPage = pages.GetCurrentPage(); RunStep firstStep = firstPage.Values[0]; RunStep secondStep = firstPage.Values[1]; @@ -364,12 +366,12 @@ public void SettingResponseFormatWorks() Validate(thread); ThreadMessage message = client.CreateMessage(thread, MessageRole.User, ["Write some JSON for me!"]); Validate(message); - ThreadRun run = client.CreateRun(thread, assistant, new() + RunOperation runOperation = client.CreateRun(ReturnWhen.Completed, thread, assistant, new() { ResponseFormat = AssistantResponseFormat.JsonObject, }); - Validate(run); - Assert.That(run.ResponseFormat, Is.EqualTo(AssistantResponseFormat.JsonObject)); + Validate(runOperation); + Assert.That(runOperation.Value.ResponseFormat, Is.EqualTo(AssistantResponseFormat.JsonObject)); } [Test] @@ -406,7 +408,8 @@ public void FunctionToolsWork() Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); - ThreadRun run = client.CreateThreadAndRun( + RunOperation runOperation = client.CreateThreadAndRun( + ReturnWhen.Started, assistant, new ThreadCreationOptions() { @@ -416,28 +419,22 @@ public void FunctionToolsWork() { AdditionalInstructions = "Call provided tools when appropriate.", }); - Validate(run); + Validate(runOperation); - for (int i = 0; i < 10 && !run.Status.IsTerminal; i++) - { - Thread.Sleep(500); - run = client.GetRun(run); - } - Assert.That(run.Status, Is.EqualTo(RunStatus.RequiresAction)); + runOperation.WaitUntilStopped(); + + ThreadRun run = runOperation.Value; + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.RequiresAction)); Assert.That(run.RequiredActions?.Count, Is.EqualTo(1)); Assert.That(run.RequiredActions[0].ToolCallId, Is.Not.Null.And.Not.Empty); Assert.That(run.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); Assert.That(run.RequiredActions[0].FunctionArguments, Is.Not.Null.And.Not.Empty); - run = client.SubmitToolOutputsToRun(run, [new(run.RequiredActions[0].ToolCallId, "tacos")]); - Assert.That(run.Status.IsTerminal, Is.False); + runOperation.SubmitToolOutputsToRun([new(run.RequiredActions[0].ToolCallId, "tacos")]); + Assert.That(runOperation.Status!.Value.IsTerminal, Is.False); - for (int i = 0; i < 10 && !run.Status.IsTerminal; i++) - { - Thread.Sleep(500); - run = client.GetRun(run); - } - Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + runOperation.WaitUntilStopped(); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); PageCollection messagePages = client.GetMessages(run.ThreadId, new MessageCollectionOptions() { Order = ListOrder.NewestFirst }); PageResult firstPage = messagePages.GetCurrentPage(); @@ -463,12 +460,11 @@ public async Task StreamingRunWorks() Stopwatch stopwatch = Stopwatch.StartNew(); void Print(string message) => Console.WriteLine($"[{stopwatch.ElapsedMilliseconds,6}] {message}"); - AsyncCollectionResult streamingResult - = client.CreateRunStreamingAsync(thread.Id, assistant.Id); + StreamingRunOperation streamingRunOperation = client.CreateRunStreaming(thread.Id, assistant.Id); Print(">>> Connected <<<"); - await foreach (StreamingUpdate update in streamingResult) + await foreach (StreamingUpdate update in streamingRunOperation.GetUpdatesStreamingAsync()) { string message = $"{update.UpdateKind} "; if (update is RunUpdate runUpdate) @@ -510,7 +506,7 @@ public async Task StreamingToolCall() void Print(string message) => Console.WriteLine($"[{stopwatch.ElapsedMilliseconds,6}] {message}"); Print(" >>> Beginning call ... "); - AsyncCollectionResult asyncResults = client.CreateThreadAndRunStreamingAsync( + StreamingRunOperation streamingRunOperation = client.CreateThreadAndRunStreaming( assistant, new() { @@ -518,39 +514,33 @@ public async Task StreamingToolCall() }); Print(" >>> Starting enumeration ..."); - ThreadRun run = null; - - do + await foreach (StreamingUpdate update in streamingRunOperation.GetUpdatesStreamingAsync()) { - run = null; - List toolOutputs = []; - await foreach (StreamingUpdate update in asyncResults) + string message = update.UpdateKind.ToString(); + + if (update is RunUpdate runUpdate) + { + message += $" run_id:{runUpdate.Value.Id}"; + } + if (update is RequiredActionUpdate requiredActionUpdate) { - string message = update.UpdateKind.ToString(); + RequiredAction action = requiredActionUpdate.RequiredActions.First(); + Assert.That(action.FunctionName, Is.EqualTo(getWeatherTool.FunctionName)); + Assert.That(requiredActionUpdate.GetThreadRun().Status, Is.EqualTo(RunStatus.RequiresAction)); + message += $" {action.FunctionName}"; - if (update is RunUpdate runUpdate) - { - message += $" run_id:{runUpdate.Value.Id}"; - run = runUpdate.Value; - } - if (update is RequiredActionUpdate requiredActionUpdate) - { - Assert.That(requiredActionUpdate.FunctionName, Is.EqualTo(getWeatherTool.FunctionName)); - Assert.That(requiredActionUpdate.GetThreadRun().Status, Is.EqualTo(RunStatus.RequiresAction)); - message += $" {requiredActionUpdate.FunctionName}"; - toolOutputs.Add(new(requiredActionUpdate.ToolCallId, "warm and sunny")); - } - if (update is MessageContentUpdate contentUpdate) - { - message += $" {contentUpdate.Text}"; - } - Print(message); + List toolOutputs = []; + + toolOutputs.Add(new(action.ToolCallId, "warm and sunny")); + + await streamingRunOperation.SubmitToolOutputsToRunStreamingAsync(toolOutputs); } - if (toolOutputs.Count > 0) + if (update is MessageContentUpdate contentUpdate) { - asyncResults = client.SubmitToolOutputsToRunStreamingAsync(run, toolOutputs); + message += $" {contentUpdate.Text}"; } - } while (run?.Status.IsTerminal == false); + Print(message); + } } [Test] @@ -640,15 +630,9 @@ This file describes the favorite foods of several people. Assert.That(thread.ToolResources?.FileSearch?.VectorStoreIds, Has.Count.EqualTo(1)); Assert.That(thread.ToolResources.FileSearch.VectorStoreIds[0], Is.EqualTo(createdVectorStoreId)); - ThreadRun run = client.CreateRun(thread, assistant); - Validate(run); - do - { - Thread.Sleep(1000); - run = client.GetRun(run); - } while (run?.Status.IsTerminal == false); - Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - + RunOperation runOperation = client.CreateRun(ReturnWhen.Completed, thread, assistant); + Validate(runOperation); + IEnumerable messages = client.GetMessages(thread, new() { Order = ListOrder.NewestFirst }).GetAllValues(); int messageCount = 0; bool hasCake = false; @@ -933,87 +917,87 @@ public async Task Pagination_CanCastAssistantPageCollectionToConvenienceFromProt Assert.That(pageCount, Is.GreaterThanOrEqualTo(5)); } - [Test] - public void Pagination_CanRehydrateRunStepPageCollectionFromBytes() - { - AssistantClient client = GetTestClient(); - Assistant assistant = client.CreateAssistant("gpt-4o", new AssistantCreationOptions() - { - Tools = { new CodeInterpreterToolDefinition() }, - Instructions = "You help the user with mathematical descriptions and visualizations.", - }); - Validate(assistant); - - FileClient fileClient = new(); - OpenAIFileInfo equationFile = fileClient.UploadFile( - BinaryData.FromString(""" - x,y - 2,5 - 7,14, - 8,22 - """).ToStream(), - "text/csv", - FileUploadPurpose.Assistants); - Validate(equationFile); - - AssistantThread thread = client.CreateThread(new ThreadCreationOptions() - { - InitialMessages = - { - "Describe the contents of any available tool resource file." - + " Graph a linear regression and provide the coefficient of correlation." - + " Explain any code executed to evaluate.", - }, - ToolResources = new() - { - CodeInterpreter = new() - { - FileIds = { equationFile.Id }, - } - } - }); - Validate(thread); - - ThreadRun run = client.CreateRun(thread, assistant); - Validate(run); - - while (!run.Status.IsTerminal) - { - Thread.Sleep(1000); - run = client.GetRun(run); - } - Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); - Assert.That(run.Usage?.TotalTokens, Is.GreaterThan(0)); - - PageCollection pages = client.GetRunSteps(run); - IEnumerator> pageEnumerator = ((IEnumerable>)pages).GetEnumerator(); - - // Simulate rehydration of the collection - BinaryData rehydrationBytes = pages.GetCurrentPage().PageToken.ToBytes(); - ContinuationToken rehydrationToken = ContinuationToken.FromBytes(rehydrationBytes); - - PageCollection rehydratedPages = client.GetRunSteps(rehydrationToken); - IEnumerator> rehydratedPageEnumerator = ((IEnumerable>)rehydratedPages).GetEnumerator(); - - int pageCount = 0; - - while (pageEnumerator.MoveNext() && rehydratedPageEnumerator.MoveNext()) - { - PageResult page = pageEnumerator.Current; - PageResult rehydratedPage = rehydratedPageEnumerator.Current; - - Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); - - for (int i = 0; i < page.Values.Count; i++) - { - Assert.AreEqual(page.Values[0].Id, rehydratedPage.Values[0].Id); - } - - pageCount++; - } - - Assert.That(pageCount, Is.GreaterThanOrEqualTo(1)); - } + //[Test] + //public void Pagination_CanRehydrateRunStepPageCollectionFromBytes() + //{ + // AssistantClient client = GetTestClient(); + // Assistant assistant = client.CreateAssistant("gpt-4o", new AssistantCreationOptions() + // { + // Tools = { new CodeInterpreterToolDefinition() }, + // Instructions = "You help the user with mathematical descriptions and visualizations.", + // }); + // Validate(assistant); + + // FileClient fileClient = new(); + // OpenAIFileInfo equationFile = fileClient.UploadFile( + // BinaryData.FromString(""" + // x,y + // 2,5 + // 7,14, + // 8,22 + // """).ToStream(), + // "text/csv", + // FileUploadPurpose.Assistants); + // Validate(equationFile); + + // AssistantThread thread = client.CreateThread(new ThreadCreationOptions() + // { + // InitialMessages = + // { + // "Describe the contents of any available tool resource file." + // + " Graph a linear regression and provide the coefficient of correlation." + // + " Explain any code executed to evaluate.", + // }, + // ToolResources = new() + // { + // CodeInterpreter = new() + // { + // FileIds = { equationFile.Id }, + // } + // } + // }); + // Validate(thread); + + // ThreadRun run = client.CreateRun(thread, assistant); + // Validate(run); + + // while (!run.Status.IsTerminal) + // { + // Thread.Sleep(1000); + // run = client.GetRun(run); + // } + // Assert.That(run.Status, Is.EqualTo(RunStatus.Completed)); + // Assert.That(run.Usage?.TotalTokens, Is.GreaterThan(0)); + + // PageCollection pages = client.GetRunSteps(run); + // IEnumerator> pageEnumerator = ((IEnumerable>)pages).GetEnumerator(); + + // // Simulate rehydration of the collection + // BinaryData rehydrationBytes = pages.GetCurrentPage().PageToken.ToBytes(); + // ContinuationToken rehydrationToken = ContinuationToken.FromBytes(rehydrationBytes); + + // PageCollection rehydratedPages = client.GetRunSteps(rehydrationToken); + // IEnumerator> rehydratedPageEnumerator = ((IEnumerable>)rehydratedPages).GetEnumerator(); + + // int pageCount = 0; + + // while (pageEnumerator.MoveNext() && rehydratedPageEnumerator.MoveNext()) + // { + // PageResult page = pageEnumerator.Current; + // PageResult rehydratedPage = rehydratedPageEnumerator.Current; + + // Assert.AreEqual(page.Values.Count, rehydratedPage.Values.Count); + + // for (int i = 0; i < page.Values.Count; i++) + // { + // Assert.AreEqual(page.Values[0].Id, rehydratedPage.Values[0].Id); + // } + + // pageCount++; + // } + + // Assert.That(pageCount, Is.GreaterThanOrEqualTo(1)); + //} [Test] public async Task MessagesWithRoles() @@ -1071,33 +1055,962 @@ async Task RefreshMessageListAsync() Assert.That(messages[2].Content[0].Text, Is.EqualTo(assistantMessageText)); } - /// - /// Performs basic, invariant validation of a target that was just instantiated from its corresponding origination - /// mechanism. If applicable, the instance is recorded into the test run for cleanup of persistent resources. - /// - /// Instance type being validated. - /// The instance to validate. - /// The provided instance type isn't supported. - private void Validate(T target) + #region LRO Tests + + [Test] + public void LRO_ProtocolOnly_Polling_CanWaitForThreadRunToComplete() { - if (target is Assistant assistant) + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + string json = $"{{\"assistant_id\":\"{assistant.Id}\"}}"; + BinaryContent content = BinaryContent.Create(BinaryData.FromString(json)); + + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread.Id, content); + + PipelineResponse response = runOperation.GetRawResponse(); + using JsonDocument createdJsonDoc = JsonDocument.Parse(response.Content); + string runId = createdJsonDoc.RootElement.GetProperty("id"u8).GetString()!; + + runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(1)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runId)); + + PageResult messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.GreaterThanOrEqualTo(1)); + + runOperation.WaitUntilStopped(); + + response = runOperation.GetRawResponse(); + using JsonDocument completedJsonDoc = JsonDocument.Parse(response.Content); + string status = completedJsonDoc.RootElement.GetProperty("status"u8).GetString()!; + + Assert.That(status, Is.EqualTo(RunStatus.Completed.ToString())); + Assert.That(runOperation.IsCompleted, Is.True); + + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + + Assert.That(messagesPage.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messagesPage.Values[1].Role, Is.EqualTo(MessageRole.User)); + Assert.That(messagesPage.Values[1].Id, Is.EqualTo(message.Id)); + } + + [Test] + public async Task LRO_ProtocolOnly_Streaming_CanWaitForThreadRunToComplete() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + // Create streaming + string json = $"{{\"assistant_id\":\"{assistant.Id}\", \"stream\":true}}"; + BinaryContent content = BinaryContent.Create(BinaryData.FromString(json)); + RequestOptions options = new() { BufferResponse = false }; + + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread.Id, content, options); + + // For streaming on protocol, if you call Wait, it will throw. + Assert.Throws(() => runOperation.WaitUntilStopped()); + + // Instead, callers must get the response stream and parse it. + PipelineResponse response = runOperation.GetRawResponse(); + IAsyncEnumerable> events = SseParser.Create( + response.ContentStream, + (_, bytes) => bytes.ToArray()).EnumerateAsync(); + + bool first = true; + string runId = default; + string status = default; + + PageResult messagesPage = default; + + await foreach (var sseItem in events) { - Assert.That(assistant?.Id, Is.Not.Null); - _assistantsToDelete.Add(assistant); + if (BinaryData.FromBytes(sseItem.Data).ToString() == "[DONE]") + { + continue; + } + + using JsonDocument doc = JsonDocument.Parse(sseItem.Data); + + if (first) + { + Assert.That(sseItem.EventType, Is.EqualTo("thread.run.created")); + + runId = doc.RootElement.GetProperty("id"u8).GetString()!; + + runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(1)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runId)); + + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.GreaterThanOrEqualTo(1)); + + first = false; + } + + string prefix = sseItem.EventType.AsSpan().Slice(0, 11).ToString(); + string suffix = sseItem.EventType.AsSpan().Slice(11).ToString(); + if (prefix == "thread.run." && !suffix.Contains("step.")) + { + status = doc.RootElement.GetProperty("status"u8).GetString()!; + + // Note: the below doesn't work because 'created' isn't a valid status. + //Assert.That(suffix, Is.EqualTo(status)); + } } - else if (target is AssistantThread thread) + + Assert.That(status, Is.EqualTo(RunStatus.Completed.ToString())); + + // For streaming on protocol, if you read IsCompleted, it will throw. + Assert.Throws(() => { bool b = runOperation.IsCompleted; }); + + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + + Assert.That(messagesPage.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messagesPage.Values[1].Role, Is.EqualTo(MessageRole.User)); + Assert.That(messagesPage.Values[1].Id, Is.EqualTo(message.Id)); + } + + [Test] + public async Task LRO_ProtocolOnly_Streaming_CanCancelThreadRun() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + string threadId = thread.Id; + string runId = default; + + // Create streaming + string json = $"{{\"assistant_id\":\"{assistant.Id}\", \"stream\":true}}"; + BinaryContent content = BinaryContent.Create(BinaryData.FromString(json)); + RequestOptions options = new() { BufferResponse = false }; + + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread.Id, content, options); + + // Instead, callers must get the response stream and parse it. + PipelineResponse response = runOperation.GetRawResponse(); + IAsyncEnumerable> events = SseParser.Create( + response.ContentStream, + (_, bytes) => bytes.ToArray()).EnumerateAsync(); + + bool first = true; + string status = default; + + PageResult messagesPage = default; + + await foreach (var sseItem in events) { - Assert.That(thread?.Id, Is.Not.Null); - _threadsToDelete.Add(thread); + if (BinaryData.FromBytes(sseItem.Data).ToString() == "[DONE]") + { + continue; + } + + using JsonDocument doc = JsonDocument.Parse(sseItem.Data); + + if (first) + { + Assert.That(sseItem.EventType, Is.EqualTo("thread.run.created")); + + runId = doc.RootElement.GetProperty("id"u8).GetString()!; + + runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(1)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runId)); + + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.GreaterThanOrEqualTo(1)); + + first = false; + + // Cancel the run while reading the event stream + runOperation.CancelRun(threadId, runId, options: default); + } + + string prefix = sseItem.EventType.AsSpan().Slice(0, 11).ToString(); + string suffix = sseItem.EventType.AsSpan().Slice(11).ToString(); + if (prefix == "thread.run." && !suffix.Contains("step.")) + { + status = doc.RootElement.GetProperty("status"u8).GetString()!; + } } - else if (target is ThreadMessage message) + + Assert.That(status, Is.EqualTo(RunStatus.Cancelled.ToString())); + } + + [Test] + public async Task LRO_ProtocolOnly_Streaming_CanSubmitToolOutputs() + { + FunctionToolDefinition getTemperatureTool = new() { - Assert.That(message?.Id, Is.Not.Null); - _messagesToDelete.Add(message); - } - else if (target is ThreadRun run) + FunctionName = "get_current_temperature", + Description = "Gets the current temperature at a specific location.", + Parameters = BinaryData.FromString(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g., San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["Celsius", "Fahrenheit"], + "description": "The temperature unit to use. Infer this from the user's location." + } + } + } + """), + }; + + FunctionToolDefinition getRainProbabilityTool = new() { - Assert.That(run?.Id, Is.Not.Null); + FunctionName = "get_current_rain_probability", + Description = "Gets the current forecasted probability of rain at a specific location," + + " represented as a percent chance in the range of 0 to 100.", + Parameters = BinaryData.FromString(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g., San Francisco, CA" + } + }, + "required": ["location"] + } + """), + }; + + // TODO: this test does not send a message that requires action -- update it + // to do that. + + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + string threadId = thread.Id; + string runId = default; + + // Create streaming + string json = $"{{\"assistant_id\":\"{assistant.Id}\", \"stream\":true}}"; + BinaryContent content = BinaryContent.Create(BinaryData.FromString(json)); + RequestOptions options = new() { BufferResponse = false }; + + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread.Id, content, options); + + // Instead, callers must get the response stream and parse it. + PipelineResponse response = runOperation.GetRawResponse(); + IAsyncEnumerable> events = SseParser.Create( + response.ContentStream, + (_, bytes) => bytes.ToArray()).EnumerateAsync(); + + string status = default; + byte[] data = default; + + await foreach (var sseItem in events) + { + if (BinaryData.FromBytes(sseItem.Data).ToString() == "[DONE]") + { + continue; + } + + using JsonDocument doc = JsonDocument.Parse(sseItem.Data); + + string prefix = sseItem.EventType.AsSpan().Slice(0, 11).ToString(); + string suffix = sseItem.EventType.AsSpan().Slice(11).ToString(); + if (prefix == "thread.run." && !suffix.Contains("step.")) + { + status = doc.RootElement.GetProperty("status"u8).GetString()!; + data = sseItem.Data; + } + } + + if (status == "requires_action") + { + using JsonDocument doc = JsonDocument.Parse(data); + IEnumerable toolCallJsonElements = doc.RootElement + .GetProperty("required_action") + .GetProperty("submit_tool_outputs") + .GetProperty("tool_calls").EnumerateArray(); + + List outputsToSubmit = []; + + foreach (JsonElement toolCallJsonElement in toolCallJsonElements) + { + string functionName = toolCallJsonElement.GetProperty("function").GetProperty("name").GetString(); + string toolCallId = toolCallJsonElement.GetProperty("id").GetString(); + + if (functionName == getTemperatureTool.FunctionName) + { + outputsToSubmit.Add(new ToolOutput(toolCallId, "57")); + } + else if (functionName == getRainProbabilityTool.FunctionName) + { + outputsToSubmit.Add(new ToolOutput(toolCallId, "25%")); + } + } + + await runOperation.SubmitToolOutputsToRunAsync(outputsToSubmit); + } + + Assert.That(status, Is.EqualTo(RunStatus.Completed.ToString())); + } + + + [Test] + public void LRO_Convenience_Polling_CanWaitForThreadRunToComplete() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + // Create polling + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread, assistant); + + Assert.That(runOperation.IsCompleted, Is.False); + Assert.That(runOperation.ThreadId, Is.EqualTo(thread.Id)); + Assert.That(runOperation.RunId, Is.Not.Null); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Queued)); + Assert.That(runOperation.Value, Is.Not.Null); + Assert.That(runOperation.Value.Id, Is.EqualTo(runOperation.RunId)); + + // Wait for operation to complete. + runOperation.WaitUntilStopped(); + + runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(1)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runOperation.RunId)); + + Assert.That(runOperation.IsCompleted, Is.True); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(runOperation.Value.Status, Is.EqualTo(RunStatus.Completed)); + + PageResult messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + + Assert.That(messagesPage.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messagesPage.Values[1].Role, Is.EqualTo(MessageRole.User)); + Assert.That(messagesPage.Values[1].Id, Is.EqualTo(message.Id)); + } + + [Test] + public async Task LRO_Convenience_Polling_CanPollWithCustomPollingInterval() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + // Create polling + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread, assistant); + + Assert.That(runOperation.IsCompleted, Is.False); + Assert.That(runOperation.ThreadId, Is.EqualTo(thread.Id)); + Assert.That(runOperation.RunId, Is.Not.Null); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Queued)); + Assert.That(runOperation.Value, Is.Not.Null); + Assert.That(runOperation.Value.Id, Is.EqualTo(runOperation.RunId)); + + // Poll manually to implement custom poll interval + IEnumerable updates = runOperation.GetUpdates(TimeSpan.Zero); + + int i = 0; + foreach (ThreadRun update in updates) + { + // Change polling interval for each update + await Task.Delay(i++ * 100); + } + + Assert.That(runOperation.IsCompleted, Is.True); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(runOperation.Value.Status, Is.EqualTo(RunStatus.Completed)); + } + + [Test] + public void LRO_Convenience_Polling_CanSubmitToolUpdates_Wait() + { + AssistantClient client = GetTestClient(); + + #region Create Assistant with Tools + + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo", new AssistantCreationOptions() + { + Tools = + { + new FunctionToolDefinition() + { + FunctionName = "get_favorite_food_for_day_of_week", + Description = "gets the user's favorite food for a given day of the week, like Tuesday", + Parameters = BinaryData.FromObjectAsJson(new + { + type = "object", + properties = new + { + day_of_week = new + { + type = "string", + description = "a day of the week, like Tuesday or Saturday", + } + } + }), + }, + }, + }); + Validate(assistant); + Assert.That(assistant.Tools?.Count, Is.EqualTo(1)); + + FunctionToolDefinition responseToolDefinition = assistant.Tools[0] as FunctionToolDefinition; + Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); + + #endregion + + #region Create Thread + + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["What should I eat on Thursday?"]); + Validate(message); + + #endregion + + // Create run polling + RunOperation runOperation = client.CreateRun( + ReturnWhen.Started, + thread, assistant, + new RunCreationOptions() + { + AdditionalInstructions = "Call provided tools when appropriate.", + }); + + while (!runOperation.IsCompleted) + { + runOperation.WaitUntilStopped(); + + if (runOperation.Status == RunStatus.RequiresAction) + { + Assert.That(runOperation.Value.RequiredActions?.Count, Is.EqualTo(1)); + Assert.That(runOperation.Value.RequiredActions[0].ToolCallId, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Value.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(runOperation.Value.RequiredActions[0].FunctionArguments, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Status?.IsTerminal, Is.False); + + IEnumerable outputs = new List { + new ToolOutput(runOperation.Value.RequiredActions[0].ToolCallId, "tacos") + }; + + runOperation.SubmitToolOutputsToRun(outputs); + } + } + + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + + PageCollection messagePages = client.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.NewestFirst }); + PageResult page = messagePages.GetCurrentPage(); + Assert.That(page.Values.Count, Is.GreaterThan(1)); + Assert.That(page.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(page.Values[0].Content?[0], Is.Not.Null); + Assert.That(page.Values[0].Content[0].Text.ToLowerInvariant(), Does.Contain("tacos")); + } + + [Test] + public void LRO_Convenience_Polling_CanSubmitToolUpdates_GetAllUpdates() + { + AssistantClient client = GetTestClient(); + + #region Create Assistant with Tools + + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo", new AssistantCreationOptions() + { + Tools = + { + new FunctionToolDefinition() + { + FunctionName = "get_favorite_food_for_day_of_week", + Description = "gets the user's favorite food for a given day of the week, like Tuesday", + Parameters = BinaryData.FromObjectAsJson(new + { + type = "object", + properties = new + { + day_of_week = new + { + type = "string", + description = "a day of the week, like Tuesday or Saturday", + } + } + }), + }, + }, + }); + Validate(assistant); + Assert.That(assistant.Tools?.Count, Is.EqualTo(1)); + + FunctionToolDefinition responseToolDefinition = assistant.Tools[0] as FunctionToolDefinition; + Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); + + #endregion + + #region Create Thread + + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["What should I eat on Thursday?"]); + Validate(message); + + #endregion + + // Create run polling + RunOperation runOperation = client.CreateRun( + ReturnWhen.Started, + thread, assistant, + new RunCreationOptions() + { + AdditionalInstructions = "Call provided tools when appropriate.", + }); + + IEnumerable updates = runOperation.GetUpdates(); + + foreach (ThreadRun update in updates) + { + if (update.Status == RunStatus.RequiresAction) + { + Assert.That(runOperation.Value.RequiredActions?.Count, Is.EqualTo(1)); + Assert.That(runOperation.Value.RequiredActions[0].ToolCallId, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Value.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(runOperation.Value.RequiredActions[0].FunctionArguments, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Status?.IsTerminal, Is.False); + + IEnumerable outputs = new List { + new ToolOutput(runOperation.Value.RequiredActions[0].ToolCallId, "tacos") + }; + + runOperation.SubmitToolOutputsToRun(outputs); + } + } + + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + + PageCollection messagePages = client.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.NewestFirst }); + PageResult page = messagePages.GetCurrentPage(); + Assert.That(page.Values.Count, Is.GreaterThan(1)); + Assert.That(page.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(page.Values[0].Content?[0], Is.Not.Null); + Assert.That(page.Values[0].Content[0].Text.ToLowerInvariant(), Does.Contain("tacos")); + } + + [Test] + public void LRO_Convenience_Polling_CanRehydrateRunOperationFromPageToken() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + // Create polling + RunOperation runOperation = client.CreateRun(ReturnWhen.Started, thread, assistant); + + // Get the rehydration token + ContinuationToken rehydrationToken = runOperation.RehydrationToken; + + // Call the rehydration method + RunOperation rehydratedRunOperation = client.ContinueRun(rehydrationToken); + + // Validate operations are equivalent + Assert.That(runOperation.ThreadId, Is.EqualTo(rehydratedRunOperation.ThreadId)); + Assert.That(runOperation.RunId, Is.EqualTo(rehydratedRunOperation.RunId)); + + // Wait for both to complete + Task.WaitAll( + Task.Run(() => runOperation.WaitUntilStopped()), + Task.Run(() => rehydratedRunOperation.WaitUntilStopped())); + + Assert.That(runOperation.Status, Is.EqualTo(rehydratedRunOperation.Status)); + + // Validate that values from both are equivalent + PageCollection runStepPages = runOperation.GetRunSteps(); + PageCollection rehydratedRunStepPages = rehydratedRunOperation.GetRunSteps(); + + List runSteps = runStepPages.GetAllValues().ToList(); + List rehydratedRunSteps = rehydratedRunStepPages.GetAllValues().ToList(); + + for (int i = 0; i < runSteps.Count; i++) + { + Assert.AreEqual(runSteps[i].Id, rehydratedRunSteps[i].Id); + Assert.AreEqual(runSteps[i].Status, rehydratedRunSteps[i].Status); + } + + Assert.AreEqual(runSteps.Count, rehydratedRunSteps.Count); + Assert.That(runSteps.Count, Is.GreaterThan(0)); + } + + [Test] + public async Task LRO_Convenience_Streaming_CanWaitForThreadRunToComplete() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + // Create streaming + StreamingRunOperation runOperation = client.CreateRunStreaming(thread, assistant); + + // Before the response stream has been enumerated, all the public properties + // should still be null. + Assert.That(runOperation.IsCompleted, Is.False); + Assert.That(runOperation.ThreadId, Is.Null); + Assert.That(runOperation.RunId, Is.Null); + Assert.That(runOperation.Status, Is.Null); + Assert.That(runOperation.Value, Is.Null); + + // Wait for operation to complete, as implemented in streaming operation type. + await runOperation.WaitUntilStoppedAsync(); + + // Validate that req/response operation work with streaming + IAsyncEnumerable steps = runOperation.GetRunStepsAsync().GetAllValuesAsync(); + Assert.That(await steps.CountAsync(), Is.GreaterThan(0)); + + runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(1)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runOperation.RunId)); + + Assert.That(runOperation.IsCompleted, Is.True); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(runOperation.Value.Status, Is.EqualTo(RunStatus.Completed)); + + PageResult messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + + Assert.That(messagesPage.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messagesPage.Values[1].Role, Is.EqualTo(MessageRole.User)); + Assert.That(messagesPage.Values[1].Id, Is.EqualTo(message.Id)); + } + + [Test] + public async Task LRO_Convenience_Streaming_CanGetStreamingUpdates() + { + AssistantClient client = GetTestClient(); + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo"); + Validate(assistant); + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["Hello, assistant!"]); + Validate(message); + + // Create streaming + StreamingRunOperation runOperation = client.CreateRunStreaming(thread, assistant); + + // Before the response stream has been enumerated, all the public properties + // should still be null. + Assert.That(runOperation.IsCompleted, Is.False); + Assert.That(runOperation.ThreadId, Is.Null); + Assert.That(runOperation.RunId, Is.Null); + Assert.That(runOperation.Status, Is.Null); + Assert.That(runOperation.Value, Is.Null); + + // Instead of calling Wait for the operation to complete, manually + // enumerate the updates from the update stream. + IAsyncEnumerable updates = runOperation.GetUpdatesStreamingAsync(); + await foreach (StreamingUpdate update in updates) + { + // TODO: we could print messages here, but not critical ATM. + } + + // TODO: add this back once conveniences are available + //ThreadRun retrievedRun = runOperation.GetRun(thread.Id, runOperation.RunId, options: default); + //Assert.That(retrievedRun.Id, Is.EqualTo(run.Id)); + + runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(1)); + Assert.That(runsPage.Values[0].Id, Is.EqualTo(runOperation.RunId)); + + Assert.That(runOperation.IsCompleted, Is.True); + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(runOperation.Value.Status, Is.EqualTo(RunStatus.Completed)); + Assert.That(runOperation.ThreadId, Is.EqualTo(thread.Id)); + + PageResult messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + messagesPage = client.GetMessages(thread).GetCurrentPage(); + Assert.That(messagesPage.Values.Count, Is.EqualTo(2)); + + Assert.That(messagesPage.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(messagesPage.Values[1].Role, Is.EqualTo(MessageRole.User)); + Assert.That(messagesPage.Values[1].Id, Is.EqualTo(message.Id)); + } + + [Test] + public async Task LRO_Convenience_Streaming_CanSubmitToolUpdates_GetAllUpdates() + { + AssistantClient client = GetTestClient(); + + #region Create Assistant with Tools + + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo", new AssistantCreationOptions() + { + Tools = + { + new FunctionToolDefinition() + { + FunctionName = "get_favorite_food_for_day_of_week", + Description = "gets the user's favorite food for a given day of the week, like Tuesday", + Parameters = BinaryData.FromObjectAsJson(new + { + type = "object", + properties = new + { + day_of_week = new + { + type = "string", + description = "a day of the week, like Tuesday or Saturday", + } + } + }), + }, + }, + }); + Validate(assistant); + Assert.That(assistant.Tools?.Count, Is.EqualTo(1)); + + FunctionToolDefinition responseToolDefinition = assistant.Tools[0] as FunctionToolDefinition; + Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); + + #endregion + + #region Create Thread + + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["What should I eat on Thursday?"]); + Validate(message); + + #endregion + + // Create run streaming + StreamingRunOperation runOperation = client.CreateRunStreaming(thread, assistant, + new RunCreationOptions() + { + AdditionalInstructions = "Call provided tools when appropriate.", + }); + + + IAsyncEnumerable updates = runOperation.GetUpdatesStreamingAsync(); + + await foreach (StreamingUpdate update in updates) + { + if (update is RunUpdate && + runOperation.Status == RunStatus.RequiresAction) + { + Assert.That(runOperation.Value.RequiredActions?.Count, Is.EqualTo(1)); + Assert.That(runOperation.Value.RequiredActions[0].ToolCallId, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Value.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(runOperation.Value.RequiredActions[0].FunctionArguments, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Status?.IsTerminal, Is.False); + + IEnumerable outputs = new List { + new ToolOutput(runOperation.Value.RequiredActions[0].ToolCallId, "tacos") + }; + + await runOperation.SubmitToolOutputsToRunStreamingAsync(outputs); + } + } + + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + + PageCollection messagePages = client.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.NewestFirst }); + PageResult page = messagePages.GetCurrentPage(); + Assert.That(page.Values.Count, Is.GreaterThan(1)); + Assert.That(page.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(page.Values[0].Content?[0], Is.Not.Null); + Assert.That(page.Values[0].Content[0].Text.ToLowerInvariant(), Does.Contain("tacos")); + } + + [Test] + public async Task LRO_Convenience_Streaming_CanSubmitToolUpdates_Wait() + { + AssistantClient client = GetTestClient(); + + #region Create Assistant with Tools + + Assistant assistant = client.CreateAssistant("gpt-3.5-turbo", new AssistantCreationOptions() + { + Tools = + { + new FunctionToolDefinition() + { + FunctionName = "get_favorite_food_for_day_of_week", + Description = "gets the user's favorite food for a given day of the week, like Tuesday", + Parameters = BinaryData.FromObjectAsJson(new + { + type = "object", + properties = new + { + day_of_week = new + { + type = "string", + description = "a day of the week, like Tuesday or Saturday", + } + } + }), + }, + }, + }); + Validate(assistant); + Assert.That(assistant.Tools?.Count, Is.EqualTo(1)); + + FunctionToolDefinition responseToolDefinition = assistant.Tools[0] as FunctionToolDefinition; + Assert.That(responseToolDefinition?.FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(responseToolDefinition?.Parameters, Is.Not.Null); + + #endregion + + #region Create Thread + + AssistantThread thread = client.CreateThread(); + Validate(thread); + PageResult runsPage = client.GetRuns(thread).GetCurrentPage(); + Assert.That(runsPage.Values.Count, Is.EqualTo(0)); + ThreadMessage message = client.CreateMessage(thread.Id, MessageRole.User, ["What should I eat on Thursday?"]); + Validate(message); + + #endregion + + // Create run streaming + StreamingRunOperation runOperation = client.CreateRunStreaming(thread, assistant, + new RunCreationOptions() + { + AdditionalInstructions = "Call provided tools when appropriate.", + }); + + do + { + await runOperation.WaitUntilStoppedAsync(); + + if (runOperation.Status == RunStatus.RequiresAction) + { + Assert.That(runOperation.Value.RequiredActions?.Count, Is.EqualTo(1)); + Assert.That(runOperation.Value.RequiredActions[0].ToolCallId, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Value.RequiredActions[0].FunctionName, Is.EqualTo("get_favorite_food_for_day_of_week")); + Assert.That(runOperation.Value.RequiredActions[0].FunctionArguments, Is.Not.Null.And.Not.Empty); + Assert.That(runOperation.Status?.IsTerminal, Is.False); + + IEnumerable outputs = new List { + new ToolOutput(runOperation.Value.RequiredActions[0].ToolCallId, "tacos") + }; + + await runOperation.SubmitToolOutputsToRunStreamingAsync(outputs); + } + } + while (!runOperation.IsCompleted); + + Assert.That(runOperation.Status, Is.EqualTo(RunStatus.Completed)); + + PageCollection messagePages = client.GetMessages(runOperation.ThreadId, new MessageCollectionOptions() { Order = ListOrder.NewestFirst }); + PageResult page = messagePages.GetCurrentPage(); + Assert.That(page.Values.Count, Is.GreaterThan(1)); + Assert.That(page.Values[0].Role, Is.EqualTo(MessageRole.Assistant)); + Assert.That(page.Values[0].Content?[0], Is.Not.Null); + Assert.That(page.Values[0].Content[0].Text.ToLowerInvariant(), Does.Contain("tacos")); + } + + #endregion + + /// + /// Performs basic, invariant validation of a target that was just instantiated from its corresponding origination + /// mechanism. If applicable, the instance is recorded into the test run for cleanup of persistent resources. + /// + /// Instance type being validated. + /// The instance to validate. + /// The provided instance type isn't supported. + private void Validate(T target) + { + if (target is Assistant assistant) + { + Assert.That(assistant?.Id, Is.Not.Null); + _assistantsToDelete.Add(assistant); + } + else if (target is AssistantThread thread) + { + Assert.That(thread?.Id, Is.Not.Null); + _threadsToDelete.Add(thread); + } + else if (target is ThreadMessage message) + { + Assert.That(message?.Id, Is.Not.Null); + _messagesToDelete.Add(message); + } + else if (target is ThreadRun run) + { + Assert.That(run?.Id, Is.Not.Null); + } + else if (target is RunOperation runOperation) + { + Assert.That(runOperation?.ThreadId, Is.Not.Null); + Assert.That(runOperation?.RunId, Is.Not.Null); } else if (target is OpenAIFileInfo file) { diff --git a/tests/Assistants/VectorStoreTests.cs b/tests/Assistants/VectorStoreTests.cs index ea1d2b7a..bf58a476 100644 --- a/tests/Assistants/VectorStoreTests.cs +++ b/tests/Assistants/VectorStoreTests.cs @@ -271,7 +271,7 @@ public void CanUseBatchIngestion() IReadOnlyList testFiles = GetNewTestFiles(5); - VectorStoreBatchFileJob batchJob = client.CreateBatchFileJob(vectorStore, testFiles); + VectorStoreFileBatchOperation batchJob = client.CreateBatchFileJob(ReturnWhen.Started, vectorStore, testFiles); Validate(batchJob); Assert.Multiple(() => @@ -281,12 +281,9 @@ public void CanUseBatchIngestion() Assert.That(batchJob.Status, Is.EqualTo(VectorStoreBatchFileJobStatus.InProgress)); }); - for (int i = 0; i < 10 && client.GetBatchFileJob(batchJob).Value.Status != VectorStoreBatchFileJobStatus.Completed; i++) - { - Thread.Sleep(500); - } + batchJob.WaitForCompletion(); - foreach (VectorStoreFileAssociation association in client.GetFileAssociations(batchJob).GetAllValues()) + foreach (VectorStoreFileAssociation association in batchJob.GetFileAssociations().GetAllValues()) { Assert.Multiple(() => { @@ -382,10 +379,10 @@ protected void Cleanup() { ErrorOptions = ClientErrorBehaviors.NoThrow, }; - foreach (VectorStoreBatchFileJob job in _jobsToCancel) + foreach (VectorStoreFileBatchOperation operation in _operationsToCancel) { - ClientResult protocolResult = vectorStoreClient.CancelBatchFileJob(job.VectorStoreId, job.BatchId, requestOptions); - Console.WriteLine($"Cleanup: {job.BatchId} => {protocolResult?.GetRawResponse()?.Status}"); + ClientResult protocolResult = operation.CancelBatchFileJob(operation.VectorStoreId, operation.BatchId, requestOptions); + Console.WriteLine($"Cleanup: {operation.BatchId} => {protocolResult?.GetRawResponse()?.Status}"); } foreach (VectorStoreFileAssociation association in _associationsToRemove) { @@ -413,10 +410,15 @@ protected void Cleanup() /// The provided instance type isn't supported. private void Validate(T target) { - if (target is VectorStoreBatchFileJob job) + if (target is VectorStoreFileBatchOperation operation) { - Assert.That(job.BatchId, Is.Not.Null); - _jobsToCancel.Add(job); + Assert.That(operation.VectorStoreId, Is.Not.Null); + Assert.That(operation.BatchId, Is.Not.Null); + + Assert.That(operation.Value, Is.Not.Null); + Assert.That(operation.Status, Is.Not.Null); + + _operationsToCancel.Add(operation); } else if (target is VectorStoreFileAssociation association) { @@ -440,7 +442,7 @@ private void Validate(T target) } } - private readonly List _jobsToCancel = []; + private readonly List _operationsToCancel = []; private readonly List _associationsToRemove = []; private readonly List _filesToDelete = []; private readonly List _vectorStoresToDelete = []; diff --git a/tests/Batch/BatchTests.cs b/tests/Batch/BatchTests.cs index 16b7c5b4..1289d74a 100644 --- a/tests/Batch/BatchTests.cs +++ b/tests/Batch/BatchTests.cs @@ -75,11 +75,11 @@ public async Task CreateGetAndCancelBatchProtocol() testMetadataKey = "test metadata value", }, })); - ClientResult batchResult = IsAsync - ? await client.CreateBatchAsync(content) - : client.CreateBatch(content); + BatchOperation batchOperation = IsAsync + ? await client.CreateBatchAsync(ReturnWhen.Started, content) + : client.CreateBatch(ReturnWhen.Started, content); - BinaryData response = batchResult.GetRawResponse().Content; + BinaryData response = batchOperation.GetRawResponse().Content; JsonDocument jsonDocument = JsonDocument.Parse(response); JsonElement idElement = jsonDocument.RootElement.GetProperty("id"); @@ -100,18 +100,14 @@ public async Task CreateGetAndCancelBatchProtocol() Assert.That(status, Is.EqualTo("validating")); Assert.That(testMetadataKey, Is.EqualTo("test metadata value")); - batchResult = IsAsync - ? await client.GetBatchAsync(id, options: null) - : client.GetBatch(id, options: null); - JsonElement endpointElement = jsonDocument.RootElement.GetProperty("endpoint"); string endpoint = endpointElement.GetString(); Assert.That(endpoint, Is.EqualTo("/v1/chat/completions")); - batchResult = IsAsync - ? await client.CancelBatchAsync(id, options: null) - : client.CancelBatch(id, options: null); + ClientResult clientResult = IsAsync + ? await batchOperation.CancelBatchAsync(id, options: null) + : batchOperation.CancelBatch(id, options: null); statusElement = jsonDocument.RootElement.GetProperty("status"); status = statusElement.GetString(); diff --git a/tests/FineTuning/FineTuningTests.cs b/tests/FineTuning/FineTuningTests.cs new file mode 100644 index 00000000..f7d5f978 --- /dev/null +++ b/tests/FineTuning/FineTuningTests.cs @@ -0,0 +1,89 @@ +using NUnit.Framework; +using OpenAI.Files; +using OpenAI.FineTuning; +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using static OpenAI.Tests.TestHelpers; + +namespace OpenAI.Tests.FineTuning; + +#pragma warning disable OPENAI001 + +[Parallelizable(ParallelScope.Fixtures)] +public partial class FineTuningTests +{ + [Test] + public void BasicFineTuningOperationsWork() + { + // Upload training file first + FileClient fileClient = GetTestClient(TestScenario.Files); + string filename = "toy_chat.jsonl"; + BinaryData fileContent = BinaryData.FromString(""" + {"messages": [{"role": "user", "content": "I lost my book today."}, {"role": "assistant", "content": "You can read everything on ebooks these days!"}]} + {"messages": [{"role": "system", "content": "You are a happy assistant that puts a positive spin on everything."}, {"role": "assistant", "content": "You're great!"}]} + """); + OpenAIFileInfo uploadedFile = fileClient.UploadFile(fileContent, filename, FileUploadPurpose.FineTune); + Assert.That(uploadedFile?.Filename, Is.EqualTo(filename)); + + // Submit fine-tuning job + FineTuningClient client = GetTestClient(); + + string json = $"{{\"training_file\":\"{uploadedFile.Id}\",\"model\":\"gpt-3.5-turbo\"}}"; + BinaryData input = BinaryData.FromString(json); + using BinaryContent content = BinaryContent.Create(input); + + ClientResult result = client.CreateJob(content); + } + + [OneTimeTearDown] + protected void Cleanup() + { + // Skip cleanup if there is no API key (e.g., if we are not running live tests). + if (string.IsNullOrEmpty(Environment.GetEnvironmentVariable("OPEN_API_KEY"))) + { + return; + } + + FileClient fileClient = new(); + RequestOptions requestOptions = new() + { + ErrorOptions = ClientErrorBehaviors.NoThrow, + }; + foreach (OpenAIFileInfo file in _filesToDelete) + { + Console.WriteLine($"Cleanup: {file.Id} -> {fileClient.DeleteFile(file.Id, requestOptions)?.GetRawResponse().Status}"); + } + _filesToDelete.Clear(); + } + + /// + /// Performs basic, invariant validation of a target that was just instantiated from its corresponding origination + /// mechanism. If applicable, the instance is recorded into the test run for cleanup of persistent resources. + /// + /// Instance type being validated. + /// The instance to validate. + /// The provided instance type isn't supported. + private void Validate(T target) + { + if (target is OpenAIFileInfo file) + { + Assert.That(file?.Id, Is.Not.Null); + _filesToDelete.Add(file); + } + else + { + throw new NotImplementedException($"{nameof(Validate)} helper not implemented for: {typeof(T)}"); + } + } + + private readonly List _filesToDelete = []; + + private static FineTuningClient GetTestClient() => GetTestClient(TestScenario.FineTuning); + + private static readonly DateTimeOffset s_2024 = new(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + private static readonly string s_cleanupMetadataKey = $"test_metadata_cleanup_eligible"; +} + +#pragma warning restore OPENAI001 diff --git a/tests/OpenAI.Tests.csproj b/tests/OpenAI.Tests.csproj index b33b036e..9c403858 100644 --- a/tests/OpenAI.Tests.csproj +++ b/tests/OpenAI.Tests.csproj @@ -1,4 +1,4 @@ - + net8.0 @@ -14,5 +14,6 @@ + \ No newline at end of file diff --git a/tests/Utility/System.Net.ServerSentEvents.cs b/tests/Utility/System.Net.ServerSentEvents.cs new file mode 100644 index 00000000..ea0b4f81 --- /dev/null +++ b/tests/Utility/System.Net.ServerSentEvents.cs @@ -0,0 +1,623 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// This file contains a source copy of: +// https://github.com/dotnet/runtime/tree/2bd15868f12ace7cee9999af61d5c130b2603f04/src/libraries/System.Net.ServerSentEvents/src/System/Net/ServerSentEvents +// Once the System.Net.ServerSentEvents package is available, this file should be removed and replaced with a package reference. +// +// The only changes made to this code from the original are: +// - Enabled nullable reference types at file scope, and use a few null suppression operators to work around the lack of [NotNull] +// - Put into a single file for ease of management (it should not be edited in this repo). +// - Changed public types to be internal. +// - Removed a use of a [NotNull] attribute to assist in netstandard2.0 compilation. +// - Replaced a reference to a .resx string with an inline constant. + +#nullable enable + +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using System.Threading; + +namespace System.Net.ServerSentEvents +{ + /// Represents a server-sent event. + /// Specifies the type of data payload in the event. + internal readonly struct SseItem + { + /// Initializes the server-sent event. + /// The event's payload. + /// The event's type. + public SseItem(T data, string eventType) + { + Data = data; + EventType = eventType; + } + + /// Gets the event's payload. + public T Data { get; } + + /// Gets the event's type. + public string EventType { get; } + } + + /// Encapsulates a method for parsing the bytes payload of a server-sent event. + /// Specifies the type of the return value of the parser. + /// The event's type. + /// The event's payload bytes. + /// The parsed . + internal delegate T SseItemParser(string eventType, ReadOnlySpan data); + + /// Provides a parser for parsing server-sent events. + internal static class SseParser + { + /// The default ("message") for an event that did not explicitly specify a type. + public const string EventTypeDefault = "message"; + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// The stream containing the data to parse. + /// + /// The enumerable of strings, which may be enumerated synchronously or asynchronously. The strings + /// are decoded from the UTF8-encoded bytes of the payload of each event. + /// + /// is null. + /// + /// This overload has behavior equivalent to calling with a delegate + /// that decodes the data of each event using 's GetString method. + /// + public static SseParser Create(Stream sseStream) => + Create(sseStream, static (_, bytes) => Utf8GetString(bytes)); + + /// Creates a parser for parsing a of server-sent events into a sequence of values. + /// Specifies the type of data in each event. + /// The stream containing the data to parse. + /// The parser to use to transform each payload of bytes into a data element. + /// The enumerable, which may be enumerated synchronously or asynchronously. + /// is null. + /// is null. + public static SseParser Create(Stream sseStream, SseItemParser itemParser) => + new SseParser( + sseStream ?? throw new ArgumentNullException(nameof(sseStream)), + itemParser ?? throw new ArgumentNullException(nameof(itemParser))); + + /// Encoding.UTF8.GetString(bytes) + internal static string Utf8GetString(ReadOnlySpan bytes) + { +#if NET + return Encoding.UTF8.GetString(bytes); +#else + unsafe + { + fixed (byte* ptr = bytes) + { + return ptr is null ? + string.Empty : + Encoding.UTF8.GetString(ptr, bytes.Length); + } + } +#endif + } + } + + /// Provides a parser for server-sent events information. + /// Specifies the type of data parsed from an event. + internal sealed class SseParser + { + // For reference: + // Specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events + + /// Carriage Return. + private const byte CR = (byte)'\r'; + /// Line Feed. + private const byte LF = (byte)'\n'; + /// Carriage Return Line Feed. + private static ReadOnlySpan CRLF => "\r\n"u8; + + /// The default size of an ArrayPool buffer to rent. + /// Larger size used by default to minimize number of reads. Smaller size used in debug to stress growth/shifting logic. + private const int DefaultArrayPoolRentSize = +#if DEBUG + 16; +#else + 1024; +#endif + + /// The stream to be parsed. + private readonly Stream _stream; + /// The parser delegate used to transform bytes into a . + private readonly SseItemParser _itemParser; + + /// Indicates whether the enumerable has already been used for enumeration. + private int _used; + + /// Buffer, either empty or rented, containing the data being read from the stream while looking for the next line. + private byte[] _lineBuffer = []; + /// The starting offset of valid data in . + private int _lineOffset; + /// The length of valid data in , starting from . + private int _lineLength; + /// The index in where a newline ('\r', '\n', or "\r\n") was found. + private int _newlineIndex; + /// The index in of characters already checked for newlines. + /// + /// This is to avoid O(LineLength^2) behavior in the rare case where we have long lines that are built-up over multiple reads. + /// We want to avoid re-checking the same characters we've already checked over and over again. + /// + private int _lastSearchedForNewline; + /// Set when eof has been reached in the stream. + private bool _eof; + + /// Rented buffer containing buffered data for the next event. + private byte[]? _dataBuffer; + /// The length of valid data in , starting from index 0. + private int _dataLength; + /// Whether data has been appended to . + /// This can be different than != 0 if empty data was appended. + private bool _dataAppended; + + /// The event type for the next event. + private string _eventType = SseParser.EventTypeDefault; + + /// Initialize the enumerable. + /// The stream to parse. + /// The function to use to parse payload bytes into a . + internal SseParser(Stream stream, SseItemParser itemParser) + { + _stream = stream; + _itemParser = itemParser; + } + + /// Gets an enumerable of the server-sent events from this parser. + /// The parser has already been enumerated. Such an exception may propagate out of a call to . + public IEnumerable> Enumerate() + { + // Validate that the parser is only used for one enumeration. + ThrowIfNotFirstEnumeration(); + + // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream, + // so we want it to be large enough to reduce the number of reads we need to do when data is + // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.) + _lineBuffer = ArrayPool.Shared.Rent(DefaultArrayPoolRentSize); + try + { + // Spec: "Event streams in this format must always be encoded as UTF-8". + // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.) + while (FillLineBuffer() != 0 && _lineLength < Utf8Bom.Length) ; + SkipBomIfPresent(); + + // Process all events in the stream. + while (true) + { + // See if there's a complete line in data already read from the stream. Lines are permitted to + // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However, + // if we only find a CR and it's at the end of the read data, don't process it now, as we want + // to process it together with an LF that might immediately follow, rather than treating them + // as two separate characters, in which case we'd incorrectly process the CR as a line by itself. + GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength); + _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF); + if (_newlineIndex >= 0) + { + _lastSearchedForNewline = -1; + _newlineIndex += searchOffset; + if (_lineBuffer[_newlineIndex] is LF || // the newline is LF + _newlineIndex - _lineOffset + 1 < _lineLength || // we must have CR and we have whatever comes after it + _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF + { + // Process the line. + if (ProcessLine(out SseItem sseItem, out int advance)) + { + yield return sseItem; + } + + // Move past the line. + _lineOffset += advance; + _lineLength -= advance; + continue; + } + } + else + { + // Record the last position searched for a newline. The next time we search, + // we'll search from here rather than from _lineOffset, in order to avoid searching + // the same characters again. + _lastSearchedForNewline = _lineOffset + _lineLength; + } + + // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done. + if (_eof) + { + // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an + // event, before the final empty line, the incomplete event is not dispatched.)" + break; + } + + // Read more data into the buffer. + FillLineBuffer(); + } + } + finally + { + ArrayPool.Shared.Return(_lineBuffer); + if (_dataBuffer is not null) + { + ArrayPool.Shared.Return(_dataBuffer); + } + } + } + + /// Gets an asynchronous enumerable of the server-sent events from this parser. + /// The cancellation token to use to cancel the enumeration. + /// The parser has already been enumerated. Such an exception may propagate out of a call to . + /// The enumeration was canceled. Such an exception may propagate out of a call to . + public async IAsyncEnumerable> EnumerateAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Validate that the parser is only used for one enumeration. + ThrowIfNotFirstEnumeration(); + + // Rent a line buffer. This will grow as needed. The line buffer is what's passed to the stream, + // so we want it to be large enough to reduce the number of reads we need to do when data is + // arriving quickly. (In debug, we use a smaller buffer to stress the growth and shifting logic.) + _lineBuffer = ArrayPool.Shared.Rent(DefaultArrayPoolRentSize); + try + { + // Spec: "Event streams in this format must always be encoded as UTF-8". + // Skip a UTF8 BOM if it exists at the beginning of the stream. (The BOM is defined as optional in the SSE grammar.) + while (await FillLineBufferAsync(cancellationToken).ConfigureAwait(false) != 0 && _lineLength < Utf8Bom.Length) ; + SkipBomIfPresent(); + + // Process all events in the stream. + while (true) + { + // See if there's a complete line in data already read from the stream. Lines are permitted to + // end with CR, LF, or CRLF. Look for all of them and if we find one, process the line. However, + // if we only find a CR and it's at the end of the read data, don't process it now, as we want + // to process it together with an LF that might immediately follow, rather than treating them + // as two separate characters, in which case we'd incorrectly process the CR as a line by itself. + GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength); + _newlineIndex = _lineBuffer.AsSpan(searchOffset, searchLength).IndexOfAny(CR, LF); + if (_newlineIndex >= 0) + { + _lastSearchedForNewline = -1; + _newlineIndex += searchOffset; + if (_lineBuffer[_newlineIndex] is LF || // newline is LF + _newlineIndex - _lineOffset + 1 < _lineLength || // newline is CR, and we have whatever comes after it + _eof) // if we get here, we know we have a CR at the end of the buffer, so it's definitely the whole newline if we've hit EOF + { + // Process the line. + if (ProcessLine(out SseItem sseItem, out int advance)) + { + yield return sseItem; + } + + // Move past the line. + _lineOffset += advance; + _lineLength -= advance; + continue; + } + } + else + { + // Record the last position searched for a newline. The next time we search, + // we'll search from here rather than from _lineOffset, in order to avoid searching + // the same characters again. + _lastSearchedForNewline = searchOffset + searchLength; + } + + // We've processed everything in the buffer we currently can, so if we've already read EOF, we're done. + if (_eof) + { + // Spec: "Once the end of the file is reached, any pending data must be discarded. (If the file ends in the middle of an + // event, before the final empty line, the incomplete event is not dispatched.)" + break; + } + + // Read more data into the buffer. + await FillLineBufferAsync(cancellationToken).ConfigureAwait(false); + } + } + finally + { + ArrayPool.Shared.Return(_lineBuffer); + if (_dataBuffer is not null) + { + ArrayPool.Shared.Return(_dataBuffer); + } + } + } + + /// Gets the next index and length with which to perform a newline search. + private void GetNextSearchOffsetAndLength(out int searchOffset, out int searchLength) + { + if (_lastSearchedForNewline > _lineOffset) + { + searchOffset = _lastSearchedForNewline; + searchLength = _lineLength - (_lastSearchedForNewline - _lineOffset); + } + else + { + searchOffset = _lineOffset; + searchLength = _lineLength; + } + + Debug.Assert(searchOffset >= _lineOffset, $"{searchOffset}, {_lineLength}"); + Debug.Assert(searchOffset <= _lineOffset + _lineLength, $"{searchOffset}, {_lineOffset}, {_lineLength}"); + Debug.Assert(searchOffset <= _lineBuffer.Length, $"{searchOffset}, {_lineBuffer.Length}"); + + Debug.Assert(searchLength >= 0, $"{searchLength}"); + Debug.Assert(searchLength <= _lineLength, $"{searchLength}, {_lineLength}"); + } + + private int GetNewLineLength() + { + Debug.Assert(_newlineIndex - _lineOffset < _lineLength, "Expected to be positioned at a non-empty newline"); + return _lineBuffer.AsSpan(_newlineIndex, _lineLength - (_newlineIndex - _lineOffset)).StartsWith(CRLF) ? 2 : 1; + } + + /// + /// If there's no room remaining in the line buffer, either shifts the contents + /// left or grows the buffer in order to make room for the next read. + /// + private void ShiftOrGrowLineBufferIfNecessary() + { + // If data we've read is butting up against the end of the buffer and + // it's not taking up the entire buffer, slide what's there down to + // the beginning, making room to read more data into the buffer (since + // there's no newline in the data that's there). Otherwise, if the whole + // buffer is full, grow the buffer to accommodate more data, since, again, + // what's there doesn't contain a newline and thus a line is longer than + // the current buffer accommodates. + if (_lineOffset + _lineLength == _lineBuffer.Length) + { + if (_lineOffset != 0) + { + _lineBuffer.AsSpan(_lineOffset, _lineLength).CopyTo(_lineBuffer); + if (_lastSearchedForNewline >= 0) + { + _lastSearchedForNewline -= _lineOffset; + } + _lineOffset = 0; + } + else if (_lineLength == _lineBuffer.Length) + { + GrowBuffer(ref _lineBuffer!, _lineBuffer.Length * 2); + } + } + } + + /// Processes a complete line from the SSE stream. + /// The parsed item if the method returns true. + /// How many characters to advance in the line buffer. + /// true if an SSE item was successfully parsed; otherwise, false. + private bool ProcessLine(out SseItem sseItem, out int advance) + { + ReadOnlySpan line = _lineBuffer.AsSpan(_lineOffset, _newlineIndex - _lineOffset); + + // Spec: "If the line is empty (a blank line) Dispatch the event" + if (line.IsEmpty) + { + advance = GetNewLineLength(); + + if (_dataAppended) + { + sseItem = new SseItem(_itemParser(_eventType, _dataBuffer.AsSpan(0, _dataLength)), _eventType); + _eventType = SseParser.EventTypeDefault; + _dataLength = 0; + _dataAppended = false; + return true; + } + + sseItem = default; + return false; + } + + // Find the colon separating the field name and value. + int colonPos = line.IndexOf((byte)':'); + ReadOnlySpan fieldName; + ReadOnlySpan fieldValue; + if (colonPos >= 0) + { + // Spec: "Collect the characters on the line before the first U+003A COLON character (:), and let field be that string." + fieldName = line.Slice(0, colonPos); + + // Spec: "Collect the characters on the line after the first U+003A COLON character (:), and let value be that string. + // If value starts with a U+0020 SPACE character, remove it from value." + fieldValue = line.Slice(colonPos + 1); + if (!fieldValue.IsEmpty && fieldValue[0] == (byte)' ') + { + fieldValue = fieldValue.Slice(1); + } + } + else + { + // Spec: "using the whole line as the field name, and the empty string as the field value." + fieldName = line; + fieldValue = []; + } + + if (fieldName.SequenceEqual("data"u8)) + { + // Spec: "Append the field value to the data buffer, then append a single U+000A LINE FEED (LF) character to the data buffer." + // Spec: "If the data buffer's last character is a U+000A LINE FEED (LF) character, then remove the last character from the data buffer." + + // If there's nothing currently in the data buffer and we can easily detect that this line is immediately followed by + // an empty line, we can optimize it to just handle the data directly from the line buffer, rather than first copying + // into the data buffer and dispatching from there. + if (!_dataAppended) + { + int newlineLength = GetNewLineLength(); + ReadOnlySpan remainder = _lineBuffer.AsSpan(_newlineIndex + newlineLength, _lineLength - line.Length - newlineLength); + if (!remainder.IsEmpty && + (remainder[0] is LF || (remainder[0] is CR && remainder.Length > 1))) + { + advance = line.Length + newlineLength + (remainder.StartsWith(CRLF) ? 2 : 1); + sseItem = new SseItem(_itemParser(_eventType, fieldValue), _eventType); + _eventType = SseParser.EventTypeDefault; + return true; + } + } + + // We need to copy the data from the data buffer to the line buffer. Make sure there's enough room. + if (_dataBuffer is null || _dataLength + _lineLength + 1 > _dataBuffer.Length) + { + GrowBuffer(ref _dataBuffer, _dataLength + _lineLength + 1); + } + + // Append a newline if there's already content in the buffer. + // Then copy the field value to the data buffer + if (_dataAppended) + { + _dataBuffer![_dataLength++] = LF; + } + fieldValue.CopyTo(_dataBuffer.AsSpan(_dataLength)); + _dataLength += fieldValue.Length; + _dataAppended = true; + } + else if (fieldName.SequenceEqual("event"u8)) + { + // Spec: "Set the event type buffer to field value." + _eventType = SseParser.Utf8GetString(fieldValue); + } + else if (fieldName.SequenceEqual("id"u8)) + { + // Spec: "If the field value does not contain U+0000 NULL, then set the last event ID buffer to the field value. Otherwise, ignore the field." + if (fieldValue.IndexOf((byte)'\0') < 0) + { + // Note that fieldValue might be empty, in which case LastEventId will naturally be reset to the empty string. This is per spec. + LastEventId = SseParser.Utf8GetString(fieldValue); + } + } + else if (fieldName.SequenceEqual("retry"u8)) + { + // Spec: "If the field value consists of only ASCII digits, then interpret the field value as an integer in base ten, + // and set the event stream's reconnection time to that integer. Otherwise, ignore the field." + if (long.TryParse( +#if NET7_0_OR_GREATER + fieldValue, +#else + SseParser.Utf8GetString(fieldValue), +#endif + NumberStyles.None, CultureInfo.InvariantCulture, out long milliseconds)) + { + ReconnectionInterval = TimeSpan.FromMilliseconds(milliseconds); + } + } + else + { + // We'll end up here if the line starts with a colon, producing an empty field name, or if the field name is otherwise unrecognized. + // Spec: "If the line starts with a U+003A COLON character (:) Ignore the line." + // Spec: "Otherwise, The field is ignored" + } + + advance = line.Length + GetNewLineLength(); + sseItem = default; + return false; + } + + /// Gets the last event ID. + /// This value is updated any time a new last event ID is parsed. It is not reset between SSE items. + public string LastEventId { get; private set; } = string.Empty; // Spec: "must be initialized to the empty string" + + /// Gets the reconnection interval. + /// + /// If no retry event was received, this defaults to , and it will only + /// ever be in that situation. If a client wishes to retry, the server-sent + /// events specification states that the interval may then be decided by the client implementation and should be a + /// few seconds. + /// + public TimeSpan ReconnectionInterval { get; private set; } = Timeout.InfiniteTimeSpan; + + /// Transitions the object to a used state, throwing if it's already been used. + private void ThrowIfNotFirstEnumeration() + { + if (Interlocked.Exchange(ref _used, 1) != 0) + { + throw new InvalidOperationException("The enumerable may be enumerated only once."); + } + } + + /// Reads data from the stream into the line buffer. + private int FillLineBuffer() + { + ShiftOrGrowLineBufferIfNecessary(); + + int offset = _lineOffset + _lineLength; + int bytesRead = _stream.Read( +#if NET + _lineBuffer.AsSpan(offset)); +#else + _lineBuffer, offset, _lineBuffer.Length - offset); +#endif + + if (bytesRead > 0) + { + _lineLength += bytesRead; + } + else + { + _eof = true; + bytesRead = 0; + } + + return bytesRead; + } + + /// Reads data asynchronously from the stream into the line buffer. + private async ValueTask FillLineBufferAsync(CancellationToken cancellationToken) + { + ShiftOrGrowLineBufferIfNecessary(); + + int offset = _lineOffset + _lineLength; + int bytesRead = await +#if NET + _stream.ReadAsync(_lineBuffer.AsMemory(offset), cancellationToken) +#else + new ValueTask(_stream.ReadAsync(_lineBuffer, offset, _lineBuffer.Length - offset, cancellationToken)) +#endif + .ConfigureAwait(false); + + if (bytesRead > 0) + { + _lineLength += bytesRead; + } + else + { + _eof = true; + bytesRead = 0; + } + + return bytesRead; + } + + /// Gets the UTF8 BOM. + private static ReadOnlySpan Utf8Bom => [0xEF, 0xBB, 0xBF]; + + /// Called at the beginning of processing to skip over an optional UTF8 byte order mark. + private void SkipBomIfPresent() + { + Debug.Assert(_lineOffset == 0, $"Expected _lineOffset == 0, got {_lineOffset}"); + + if (_lineBuffer.AsSpan(0, _lineLength).StartsWith(Utf8Bom)) + { + _lineOffset += 3; + _lineLength -= 3; + } + } + + /// Grows the buffer, returning the existing one to the ArrayPool and renting an ArrayPool replacement. + private static void GrowBuffer(ref byte[]? buffer, int minimumLength) + { + byte[]? toReturn = buffer; + buffer = ArrayPool.Shared.Rent(Math.Max(minimumLength, DefaultArrayPoolRentSize)); + if (toReturn is not null) + { + Array.Copy(toReturn, buffer, toReturn.Length); + ArrayPool.Shared.Return(toReturn); + } + } + } +} \ No newline at end of file diff --git a/tests/Utility/TestHelpers.cs b/tests/Utility/TestHelpers.cs index 9231231e..6ff1237a 100644 --- a/tests/Utility/TestHelpers.cs +++ b/tests/Utility/TestHelpers.cs @@ -5,6 +5,7 @@ using OpenAI.Chat; using OpenAI.Embeddings; using OpenAI.Files; +using OpenAI.FineTuning; using OpenAI.Images; using OpenAI.VectorStores; using System; @@ -54,6 +55,7 @@ public static T GetTestClient(TestScenario scenario, string overrideModel = n TestScenario.Chat => new ChatClient(overrideModel ?? "gpt-4o-mini", options), TestScenario.Embeddings => new EmbeddingClient(overrideModel ?? "text-embedding-3-small", options), TestScenario.Files => new FileClient(options), + TestScenario.FineTuning => new FineTuningClient(options), TestScenario.Images => new ImageClient(overrideModel ?? "dall-e-3", options), #pragma warning disable OPENAI001 TestScenario.VectorStores => new VectorStoreClient(options),