Skip to content

Commit b19cf20

Browse files
authored
Add GetRequiredService extension for IChatClient/EmbeddingGenerator (#5930)
* Add GetRequiredService extension for IChatClient/EmbeddingGenerator * Add non-generic GetRequiredService
1 parent 66eca03 commit b19cf20

File tree

8 files changed

+296
-5
lines changed

8 files changed

+296
-5
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System;
45
using System.Collections.Generic;
56
using System.Threading;
67
using System.Threading.Tasks;
@@ -16,6 +17,7 @@ public static class ChatClientExtensions
1617
/// <param name="client">The client.</param>
1718
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
1819
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
20+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
1921
/// <remarks>
2022
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the <see cref="IChatClient"/>,
2123
/// including itself or any services it might be wrapping.
@@ -24,7 +26,58 @@ public static class ChatClientExtensions
2426
{
2527
_ = Throw.IfNull(client);
2628

27-
return (TService?)client.GetService(typeof(TService), serviceKey);
29+
return client.GetService(typeof(TService), serviceKey) is TService service ? service : default;
30+
}
31+
32+
/// <summary>
33+
/// Asks the <see cref="IChatClient"/> for an object of the specified type <paramref name="serviceType"/>
34+
/// and throws an exception if one isn't available.
35+
/// </summary>
36+
/// <param name="client">The client.</param>
37+
/// <param name="serviceType">The type of object being requested.</param>
38+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
39+
/// <returns>The found object.</returns>
40+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
41+
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
42+
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
43+
/// <remarks>
44+
/// The purpose of this method is to allow for the retrieval of services that are required to be provided by the <see cref="IChatClient"/>,
45+
/// including itself or any services it might be wrapping.
46+
/// </remarks>
47+
public static object GetRequiredService(this IChatClient client, Type serviceType, object? serviceKey = null)
48+
{
49+
_ = Throw.IfNull(client);
50+
_ = Throw.IfNull(serviceType);
51+
52+
return
53+
client.GetService(serviceType, serviceKey) ??
54+
throw Throw.CreateMissingServiceException(serviceType, serviceKey);
55+
}
56+
57+
/// <summary>
58+
/// Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>
59+
/// and throws an exception if one isn't available.
60+
/// </summary>
61+
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
62+
/// <param name="client">The client.</param>
63+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
64+
/// <returns>The found object.</returns>
65+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
66+
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
67+
/// <remarks>
68+
/// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the <see cref="IChatClient"/>,
69+
/// including itself or any services it might be wrapping.
70+
/// </remarks>
71+
public static TService GetRequiredService<TService>(this IChatClient client, object? serviceKey = null)
72+
{
73+
_ = Throw.IfNull(client);
74+
75+
if (client.GetService(typeof(TService), serviceKey) is not TService service)
76+
{
77+
throw Throw.CreateMissingServiceException(typeof(TService), serviceKey);
78+
}
79+
80+
return service;
2881
}
2982

3083
/// <summary>Sends a user chat text message and returns the response messages.</summary>
@@ -33,6 +86,8 @@ public static class ChatClientExtensions
3386
/// <param name="options">The chat options to configure the request.</param>
3487
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
3588
/// <returns>The response messages generated by the client.</returns>
89+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
90+
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
3691
public static Task<ChatResponse> GetResponseAsync(
3792
this IChatClient client,
3893
string chatMessage,
@@ -51,6 +106,8 @@ public static Task<ChatResponse> GetResponseAsync(
51106
/// <param name="options">The chat options to configure the request.</param>
52107
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
53108
/// <returns>The response messages generated by the client.</returns>
109+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
110+
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
54111
public static Task<ChatResponse> GetResponseAsync(
55112
this IChatClient client,
56113
ChatMessage chatMessage,
@@ -69,6 +126,8 @@ public static Task<ChatResponse> GetResponseAsync(
69126
/// <param name="options">The chat options to configure the request.</param>
70127
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
71128
/// <returns>The response messages generated by the client.</returns>
129+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
130+
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
72131
public static IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
73132
this IChatClient client,
74133
string chatMessage,
@@ -87,6 +146,8 @@ public static IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
87146
/// <param name="options">The chat options to configure the request.</param>
88147
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
89148
/// <returns>The response messages generated by the client.</returns>
149+
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
150+
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
90151
public static IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
91152
this IChatClient client,
92153
ChatMessage chatMessage,

src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
5757
/// <param name="serviceType">The type of object being requested.</param>
5858
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
5959
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
60+
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
6061
/// <remarks>
6162
/// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the <see cref="IChatClient"/>,
6263
/// including itself or any services it might be wrapping.

src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.Shared.Diagnostics;
1010

1111
#pragma warning disable S2302 // "nameof" should be used
12+
#pragma warning disable S4136 // Method overloads should be grouped together
1213

1314
namespace Microsoft.Extensions.AI;
1415

@@ -22,19 +23,80 @@ public static class EmbeddingGeneratorExtensions
2223
/// <param name="generator">The generator.</param>
2324
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
2425
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
26+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
2527
/// <remarks>
2628
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
2729
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
2830
/// </remarks>
29-
public static TService? GetService<TInput, TEmbedding, TService>(this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
31+
public static TService? GetService<TInput, TEmbedding, TService>(
32+
this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
3033
where TEmbedding : Embedding
3134
{
3235
_ = Throw.IfNull(generator);
3336

34-
return (TService?)generator.GetService(typeof(TService), serviceKey);
37+
return generator.GetService(typeof(TService), serviceKey) is TService service ? service : default;
3538
}
3639

37-
// The following overload exists purely to work around the lack of partial generic type inference.
40+
/// <summary>
41+
/// Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of the specified type <paramref name="serviceType"/>
42+
/// and throws an exception if one isn't available.
43+
/// </summary>
44+
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
45+
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
46+
/// <param name="generator">The generator.</param>
47+
/// <param name="serviceType">The type of object being requested.</param>
48+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
49+
/// <returns>The found object.</returns>
50+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
51+
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
52+
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
53+
/// <remarks>
54+
/// The purpose of this method is to allow for the retrieval of services that are required to be provided by the
55+
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
56+
/// </remarks>
57+
public static object GetRequiredService<TInput, TEmbedding>(
58+
this IEmbeddingGenerator<TInput, TEmbedding> generator, Type serviceType, object? serviceKey = null)
59+
where TEmbedding : Embedding
60+
{
61+
_ = Throw.IfNull(generator);
62+
_ = Throw.IfNull(serviceType);
63+
64+
return
65+
generator.GetService(serviceType, serviceKey) ??
66+
throw Throw.CreateMissingServiceException(serviceType, serviceKey);
67+
}
68+
69+
/// <summary>
70+
/// Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>
71+
/// and throws an exception if one isn't available.
72+
/// </summary>
73+
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
74+
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
75+
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
76+
/// <param name="generator">The generator.</param>
77+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
78+
/// <returns>The found object.</returns>
79+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
80+
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
81+
/// <remarks>
82+
/// The purpose of this method is to allow for the retrieval of strongly typed services that are required to be provided by the
83+
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
84+
/// </remarks>
85+
public static TService GetRequiredService<TInput, TEmbedding, TService>(
86+
this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
87+
where TEmbedding : Embedding
88+
{
89+
_ = Throw.IfNull(generator);
90+
91+
if (generator.GetService(typeof(TService), serviceKey) is not TService service)
92+
{
93+
throw Throw.CreateMissingServiceException(typeof(TService), serviceKey);
94+
}
95+
96+
return service;
97+
}
98+
99+
// The following overloads exist purely to work around the lack of partial generic type inference.
38100
// Given an IEmbeddingGenerator<TInput, TEmbedding> generator, to call GetService with TService, you still need
39101
// to re-specify both TInput and TEmbedding, e.g. generator.GetService<string, Embedding<float>, TService>.
40102
// The case of string/Embedding<float> is by far the most common case today, so this overload exists as an
@@ -45,13 +107,31 @@ public static class EmbeddingGeneratorExtensions
45107
/// <param name="generator">The generator.</param>
46108
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
47109
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
110+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
48111
/// <remarks>
49112
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
50113
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
51114
/// </remarks>
52115
public static TService? GetService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
53116
GetService<string, Embedding<float>, TService>(generator, serviceKey);
54117

118+
/// <summary>
119+
/// Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>
120+
/// and throws an exception if one isn't available.
121+
/// </summary>
122+
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
123+
/// <param name="generator">The generator.</param>
124+
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
125+
/// <returns>The found object.</returns>
126+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
127+
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
128+
/// <remarks>
129+
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
130+
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
131+
/// </remarks>
132+
public static TService GetRequiredService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
133+
GetRequiredService<string, Embedding<float>, TService>(generator, serviceKey);
134+
55135
/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
56136
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
57137
/// <typeparam name="TEmbeddingElement">The numeric type of the embedding data.</typeparam>
@@ -60,6 +140,9 @@ public static class EmbeddingGeneratorExtensions
60140
/// <param name="options">The embedding generation options to configure the request.</param>
61141
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
62142
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
143+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
144+
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
145+
/// <exception cref="InvalidOperationException">The generator did not produce exactly one embedding.</exception>
63146
/// <remarks>
64147
/// This operation is equivalent to using <see cref="GenerateEmbeddingAsync"/> and returning the
65148
/// resulting <see cref="Embedding{T}"/>'s <see cref="Embedding{T}.Vector"/> property.
@@ -84,6 +167,9 @@ public static async Task<ReadOnlyMemory<TEmbeddingElement>> GenerateEmbeddingVec
84167
/// <returns>
85168
/// The generated embedding for the specified <paramref name="value"/>.
86169
/// </returns>
170+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
171+
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
172+
/// <exception cref="InvalidOperationException">The generator did not produce exactly one embedding.</exception>
87173
/// <remarks>
88174
/// This operations is equivalent to using <see cref="IEmbeddingGenerator{TInput, TEmbedding}.GenerateAsync"/> with a
89175
/// collection composed of the single <paramref name="value"/> and then returning the first embedding element from the
@@ -125,6 +211,9 @@ public static async Task<TEmbedding> GenerateEmbeddingAsync<TInput, TEmbedding>(
125211
/// <param name="options">The embedding generation options to configure the request.</param>
126212
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
127213
/// <returns>An array containing tuples of the input values and the associated generated embeddings.</returns>
214+
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
215+
/// <exception cref="ArgumentNullException"><paramref name="values"/> is <see langword="null"/>.</exception>
216+
/// <exception cref="InvalidOperationException">The generator did not produce one embedding for each input value.</exception>
128217
public static async Task<(TInput Value, TEmbedding Embedding)[]> GenerateAndZipAsync<TInput, TEmbedding>(
129218
this IEmbeddingGenerator<TInput, TEmbedding> generator,
130219
IEnumerable<TInput> values,

src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
4141
/// <param name="serviceType">The type of object being requested.</param>
4242
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
4343
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
44+
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
4445
/// <remarks>
4546
/// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the
4647
/// <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>, including itself or any services it might be wrapping.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
6+
namespace Microsoft.Shared.Diagnostics;
7+
8+
internal static partial class Throw
9+
{
10+
/// <summary>Throws an exception indicating that a required service is not available.</summary>
11+
public static InvalidOperationException CreateMissingServiceException(Type serviceType, object? serviceKey) =>
12+
new InvalidOperationException(serviceKey is null ?
13+
$"No service of type '{serviceType}' is available." :
14+
$"No service of type '{serviceType}' for the key '{serviceKey}' is available.");
15+
}

0 commit comments

Comments
 (0)