Skip to content

Commit 25e1215

Browse files
crickmanTaoChenOSU
andauthored
.Net Agents - Support IAutoFunctionInvocationFilter for OpenAIAssistantAgent (#9690)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Fixes: #9673 ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> - Add explicit support for `IAutoFunctionInvocationFilter` within run processing for `OpenAIAssistantAgent` - Verified expected behavior for `IFunctionInvocationFilter` - Added sample ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone 😄 --------- Co-authored-by: Tao Chen <[email protected]>
1 parent 983da69 commit 25e1215

File tree

10 files changed

+474
-113
lines changed

10 files changed

+474
-113
lines changed

dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -233,19 +233,7 @@ private Kernel CreateKernelWithFilter()
233233
{
234234
IKernelBuilder builder = Kernel.CreateBuilder();
235235

236-
if (this.UseOpenAIConfig)
237-
{
238-
builder.AddOpenAIChatCompletion(
239-
TestConfiguration.OpenAI.ChatModelId,
240-
TestConfiguration.OpenAI.ApiKey);
241-
}
242-
else
243-
{
244-
builder.AddAzureOpenAIChatCompletion(
245-
TestConfiguration.AzureOpenAI.ChatDeploymentName,
246-
TestConfiguration.AzureOpenAI.Endpoint,
247-
TestConfiguration.AzureOpenAI.ApiKey);
248-
}
236+
base.AddChatCompletionToKernel(builder);
249237

250238
builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(new AutoInvocationFilter());
251239

dotnet/samples/Concepts/Agents/OpenAIAssistant_ChartMaker.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ Sum 426 1622 856 2904
6666
async Task InvokeAgentAsync(string input)
6767
{
6868
ChatMessageContent message = new(AuthorRole.User, input);
69-
chat.AddChatMessage(new(AuthorRole.User, input));
69+
chat.AddChatMessage(message);
7070
this.WriteAgentChatMessage(message);
7171

7272
await foreach (ChatMessageContent response in chat.InvokeAsync(agent))
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
using System.ComponentModel;
3+
using Microsoft.Extensions.DependencyInjection;
4+
using Microsoft.SemanticKernel;
5+
using Microsoft.SemanticKernel.Agents;
6+
using Microsoft.SemanticKernel.Agents.OpenAI;
7+
using Microsoft.SemanticKernel.ChatCompletion;
8+
9+
namespace Agents;
10+
11+
/// <summary>
12+
/// Demonstrate usage of <see cref="IAutoFunctionInvocationFilter"/> for and
13+
/// <see cref="IFunctionInvocationFilter"/> filters with <see cref="OpenAIAssistantAgent"/>
14+
/// via <see cref="AgentChat"/>.
15+
/// </summary>
16+
public class OpenAIAssistant_FunctionFilters(ITestOutputHelper output) : BaseAgentsTest(output)
17+
{
18+
protected override bool ForceOpenAI => true; // %%% REMOVE
19+
20+
[Fact]
21+
public async Task UseFunctionInvocationFilterAsync()
22+
{
23+
// Define the agent
24+
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithInvokeFilter());
25+
26+
// Invoke assistant agent (non streaming)
27+
await InvokeAssistantAsync(agent);
28+
}
29+
30+
[Fact]
31+
public async Task UseFunctionInvocationFilterStreamingAsync()
32+
{
33+
// Define the agent
34+
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithInvokeFilter());
35+
36+
// Invoke assistant agent (streaming)
37+
await InvokeAssistantStreamingAsync(agent);
38+
}
39+
40+
[Theory]
41+
[InlineData(false)]
42+
[InlineData(true)]
43+
public async Task UseAutoFunctionInvocationFilterAsync(bool terminate)
44+
{
45+
// Define the agent
46+
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithAutoFilter(terminate));
47+
48+
// Invoke assistant agent (non streaming)
49+
await InvokeAssistantAsync(agent);
50+
}
51+
52+
[Theory]
53+
[InlineData(false)]
54+
[InlineData(true)]
55+
public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocationAsync(bool terminate)
56+
{
57+
// Define the agent
58+
OpenAIAssistantAgent agent = await CreateAssistantAsync(CreateKernelWithAutoFilter(terminate));
59+
60+
// Invoke assistant agent (streaming)
61+
await InvokeAssistantStreamingAsync(agent);
62+
}
63+
64+
private async Task InvokeAssistantAsync(OpenAIAssistantAgent agent)
65+
{
66+
// Create a thread for the agent conversation.
67+
AgentGroupChat chat = new();
68+
69+
try
70+
{
71+
// Respond to user input, invoking functions where appropriate.
72+
ChatMessageContent message = new(AuthorRole.User, "What is the special soup?");
73+
chat.AddChatMessage(message);
74+
await chat.InvokeAsync(agent).ToArrayAsync();
75+
76+
// Display the entire chat history.
77+
ChatMessageContent[] history = await chat.GetChatMessagesAsync().Reverse().ToArrayAsync();
78+
this.WriteChatHistory(history);
79+
}
80+
finally
81+
{
82+
await chat.ResetAsync();
83+
await agent.DeleteAsync();
84+
}
85+
}
86+
87+
private async Task InvokeAssistantStreamingAsync(OpenAIAssistantAgent agent)
88+
{
89+
// Create a thread for the agent conversation.
90+
AgentGroupChat chat = new();
91+
92+
try
93+
{
94+
// Respond to user input, invoking functions where appropriate.
95+
ChatMessageContent message = new(AuthorRole.User, "What is the special soup?");
96+
chat.AddChatMessage(message);
97+
await chat.InvokeStreamingAsync(agent).ToArrayAsync();
98+
99+
// Display the entire chat history.
100+
ChatMessageContent[] history = await chat.GetChatMessagesAsync().Reverse().ToArrayAsync();
101+
this.WriteChatHistory(history);
102+
}
103+
finally
104+
{
105+
await chat.ResetAsync();
106+
await agent.DeleteAsync();
107+
}
108+
}
109+
110+
private void WriteChatHistory(IEnumerable<ChatMessageContent> history)
111+
{
112+
Console.WriteLine("\n================================");
113+
Console.WriteLine("CHAT HISTORY");
114+
Console.WriteLine("================================");
115+
foreach (ChatMessageContent message in history)
116+
{
117+
this.WriteAgentChatMessage(message);
118+
}
119+
}
120+
121+
private async Task<OpenAIAssistantAgent> CreateAssistantAsync(Kernel kernel)
122+
{
123+
OpenAIAssistantAgent agent =
124+
await OpenAIAssistantAgent.CreateAsync(
125+
this.GetClientProvider(),
126+
new OpenAIAssistantDefinition(base.Model)
127+
{
128+
Instructions = "Answer questions about the menu.",
129+
Metadata = AssistantSampleMetadata,
130+
},
131+
kernel: kernel
132+
);
133+
134+
KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
135+
agent.Kernel.Plugins.Add(plugin);
136+
137+
return agent;
138+
}
139+
140+
private Kernel CreateKernelWithAutoFilter(bool terminate)
141+
{
142+
IKernelBuilder builder = Kernel.CreateBuilder();
143+
144+
base.AddChatCompletionToKernel(builder);
145+
146+
builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(new AutoInvocationFilter(terminate));
147+
148+
return builder.Build();
149+
}
150+
151+
private Kernel CreateKernelWithInvokeFilter()
152+
{
153+
IKernelBuilder builder = Kernel.CreateBuilder();
154+
155+
base.AddChatCompletionToKernel(builder);
156+
157+
builder.Services.AddSingleton<IFunctionInvocationFilter>(new InvocationFilter());
158+
159+
return builder.Build();
160+
}
161+
162+
private sealed class MenuPlugin
163+
{
164+
[KernelFunction, Description("Provides a list of specials from the menu.")]
165+
[System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")]
166+
public string GetSpecials()
167+
{
168+
return
169+
"""
170+
Special Soup: Clam Chowder
171+
Special Salad: Cobb Salad
172+
Special Drink: Chai Tea
173+
""";
174+
}
175+
176+
[KernelFunction, Description("Provides the price of the requested menu item.")]
177+
public string GetItemPrice([Description("The name of the menu item.")] string menuItem)
178+
{
179+
return "$9.99";
180+
}
181+
}
182+
183+
private sealed class InvocationFilter() : IFunctionInvocationFilter
184+
{
185+
public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
186+
{
187+
System.Console.WriteLine($"FILTER INVOKED {nameof(InvocationFilter)} - {context.Function.Name}");
188+
189+
// Execution the function
190+
await next(context);
191+
192+
// Signal termination if the function is from the MenuPlugin
193+
if (context.Function.PluginName == nameof(MenuPlugin))
194+
{
195+
context.Result = new FunctionResult(context.Function, "BLOCKED");
196+
}
197+
}
198+
}
199+
200+
private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter
201+
{
202+
public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
203+
{
204+
System.Console.WriteLine($"FILTER INVOKED {nameof(AutoInvocationFilter)} - {context.Function.Name}");
205+
206+
// Execution the function
207+
await next(context);
208+
209+
// Signal termination if the function is from the MenuPlugin
210+
if (context.Function.PluginName == nameof(MenuPlugin))
211+
{
212+
context.Terminate = terminate;
213+
}
214+
}
215+
}
216+
}

dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/System/IListExtensions.cs" Link="%(RecursiveDir)System/%(Filename)%(Extension)" />
2727
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/System/AppContextSwitchHelper.cs" Link="%(RecursiveDir)Utilities/%(Filename)%(Extension)" />
2828
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Functions/FunctionName.cs" Link="%(RecursiveDir)Utilities/%(Filename)%(Extension)" />
29+
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/connectors/AI/**/*.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
2930
</ItemGroup>
3031

3132
<ItemGroup>

dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Azure;
1212
using Microsoft.Extensions.Logging;
1313
using Microsoft.SemanticKernel.ChatCompletion;
14+
using Microsoft.SemanticKernel.Connectors.FunctionCalling;
1415
using OpenAI.Assistants;
1516

1617
namespace Microsoft.SemanticKernel.Agents.OpenAI.Internal;
@@ -177,6 +178,10 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
177178

178179
logger.LogOpenAIAssistantCreatedRun(nameof(InvokeAsync), run.Id, threadId);
179180

181+
FunctionCallsProcessor functionProcessor = new(logger);
182+
// This matches current behavior. Will be configurable upon integrating with `FunctionChoice` (#6795/#5200)
183+
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true };
184+
180185
// Evaluate status and process steps and messages, as encountered.
181186
HashSet<string> processedStepIds = [];
182187
Dictionary<string, FunctionResultContent> functionSteps = [];
@@ -206,13 +211,18 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
206211
if (functionCalls.Length > 0)
207212
{
208213
// Emit function-call content
209-
yield return (IsVisible: false, Message: GenerateFunctionCallContent(agent.GetName(), functionCalls));
214+
ChatMessageContent functionCallMessage = GenerateFunctionCallContent(agent.GetName(), functionCalls);
215+
yield return (IsVisible: false, Message: functionCallMessage);
210216

211217
// Invoke functions for each tool-step
212-
IEnumerable<Task<FunctionResultContent>> functionResultTasks = ExecuteFunctionSteps(agent, functionCalls, cancellationToken);
213-
214-
// Block for function results
215-
FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false);
218+
FunctionResultContent[] functionResults =
219+
await functionProcessor.InvokeFunctionCallsAsync(
220+
functionCallMessage,
221+
(_) => true,
222+
functionOptions,
223+
kernel,
224+
isStreaming: false,
225+
cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);
216226

217227
// Capture function-call for message processing
218228
foreach (FunctionResultContent functionCall in functionResults)
@@ -402,6 +412,10 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
402412
List<RunStep> stepsToProcess = [];
403413
ThreadRun? run = null;
404414

415+
FunctionCallsProcessor functionProcessor = new(logger);
416+
// This matches current behavior. Will be configurable upon integrating with `FunctionChoice` (#6795/#5200)
417+
FunctionChoiceBehaviorOptions functionOptions = new() { AllowConcurrentInvocation = true, AllowParallelCalls = true };
418+
405419
IAsyncEnumerable<StreamingUpdate> asyncUpdates = client.CreateRunStreamingAsync(threadId, agent.Id, options, cancellationToken);
406420
do
407421
{
@@ -495,13 +509,17 @@ await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellat
495509
if (functionCalls.Length > 0)
496510
{
497511
// Emit function-call content
498-
messages?.Add(GenerateFunctionCallContent(agent.GetName(), functionCalls));
499-
500-
// Invoke functions for each tool-step
501-
IEnumerable<Task<FunctionResultContent>> functionResultTasks = ExecuteFunctionSteps(agent, functionCalls, cancellationToken);
502-
503-
// Block for function results
504-
FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false);
512+
ChatMessageContent functionCallMessage = GenerateFunctionCallContent(agent.GetName(), functionCalls);
513+
messages?.Add(functionCallMessage);
514+
515+
FunctionResultContent[] functionResults =
516+
await functionProcessor.InvokeFunctionCallsAsync(
517+
functionCallMessage,
518+
(_) => true,
519+
functionOptions,
520+
kernel,
521+
isStreaming: true,
522+
cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);
505523

506524
// Process tool output
507525
ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);

dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeAllTypes.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone;
88

99
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
1010
public record PineconeAllTypes()
11-
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
1211
{
1312
[VectorStoreRecordKey]
1413
public string Id { get; init; }
@@ -62,3 +61,4 @@ public record PineconeAllTypes()
6261
[VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)]
6362
public ReadOnlyMemory<float>? Embedding { get; set; }
6463
}
64+
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.

dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeHotel.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone;
99

1010
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
1111
public record PineconeHotel()
12-
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.
12+
1313
{
1414
[VectorStoreRecordKey]
1515
public string HotelId { get; init; }
@@ -37,3 +37,4 @@ public record PineconeHotel()
3737
[VectorStoreRecordVector(Dimensions: 8, DistanceFunction: DistanceFunction.DotProductSimilarity)]
3838
public ReadOnlyMemory<float> DescriptionEmbedding { get; set; }
3939
}
40+
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring as nullable.

0 commit comments

Comments
 (0)