Skip to content

Add GetRequiredService extension for IChatClient/EmbeddingGenerator #5930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -16,6 +17,7 @@ public static class ChatClientExtensions
/// <param name="client">The client.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
Expand All @@ -24,7 +26,58 @@ public static class ChatClientExtensions
{
_ = Throw.IfNull(client);

return (TService?)client.GetService(typeof(TService), serviceKey);
return client.GetService(typeof(TService), serviceKey) is TService service ? service : default;
}

/// <summary>
/// Asks the <see cref="IChatClient"/> for an object of the specified type <paramref name="serviceType"/>
/// and throws an exception if one isn't available.
/// </summary>
/// <param name="client">The client.</param>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
/// </remarks>
public static object GetRequiredService(this IChatClient client, Type serviceType, object? serviceKey = null)
{
_ = Throw.IfNull(client);
_ = Throw.IfNull(serviceType);

return
client.GetService(serviceType, serviceKey) ??
throw Throw.CreateMissingServiceException(serviceType, serviceKey);
}

/// <summary>
/// Asks the <see cref="IChatClient"/> for an object of type <typeparamref name="TService"/>
/// and throws an exception if one isn't available.
/// </summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="client">The client.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
/// </remarks>
public static TService GetRequiredService<TService>(this IChatClient client, object? serviceKey = null)
{
_ = Throw.IfNull(client);

if (client.GetService(typeof(TService), serviceKey) is not TService service)
{
throw Throw.CreateMissingServiceException(typeof(TService), serviceKey);
}

return service;
}

/// <summary>Sends a user chat text message and returns the response messages.</summary>
Expand All @@ -33,6 +86,8 @@ public static class ChatClientExtensions
/// <param name="options">The chat options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The response messages generated by the client.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
public static Task<ChatResponse> GetResponseAsync(
this IChatClient client,
string chatMessage,
Expand All @@ -51,6 +106,8 @@ public static Task<ChatResponse> GetResponseAsync(
/// <param name="options">The chat options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The response messages generated by the client.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
public static Task<ChatResponse> GetResponseAsync(
this IChatClient client,
ChatMessage chatMessage,
Expand All @@ -69,6 +126,8 @@ public static Task<ChatResponse> GetResponseAsync(
/// <param name="options">The chat options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The response messages generated by the client.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
public static IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
this IChatClient client,
string chatMessage,
Expand All @@ -87,6 +146,8 @@ public static IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
/// <param name="options">The chat options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The response messages generated by the client.</returns>
/// <exception cref="ArgumentNullException"><paramref name="client"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="chatMessage"/> is <see langword="null"/>.</exception>
public static IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
this IChatClient client,
ChatMessage chatMessage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the <see cref="IChatClient"/>,
/// including itself or any services it might be wrapping.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.Shared.Diagnostics;

#pragma warning disable S2302 // "nameof" should be used
#pragma warning disable S4136 // Method overloads should be grouped together

namespace Microsoft.Extensions.AI;

Expand All @@ -22,19 +23,80 @@ public static class EmbeddingGeneratorExtensions
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TInput, TEmbedding, TService>(this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
public static TService? GetService<TInput, TEmbedding, TService>(
this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);

return (TService?)generator.GetService(typeof(TService), serviceKey);
return generator.GetService(typeof(TService), serviceKey) is TService service ? service : default;
}

// The following overload exists purely to work around the lack of partial generic type inference.
/// <summary>
/// Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of the specified type <paramref name="serviceType"/>
/// and throws an exception if one isn't available.
/// </summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static object GetRequiredService<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator, Type serviceType, object? serviceKey = null)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);
_ = Throw.IfNull(serviceType);

return
generator.GetService(serviceType, serviceKey) ??
throw Throw.CreateMissingServiceException(serviceType, serviceKey);
}

/// <summary>
/// Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>
/// and throws an exception if one isn't available.
/// </summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbedding">The numeric type of the embedding data.</typeparam>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService GetRequiredService<TInput, TEmbedding, TService>(
this IEmbeddingGenerator<TInput, TEmbedding> generator, object? serviceKey = null)
where TEmbedding : Embedding
{
_ = Throw.IfNull(generator);

if (generator.GetService(typeof(TService), serviceKey) is not TService service)
{
throw Throw.CreateMissingServiceException(typeof(TService), serviceKey);
}

return service;
}

// The following overloads exist purely to work around the lack of partial generic type inference.
// Given an IEmbeddingGenerator<TInput, TEmbedding> generator, to call GetService with TService, you still need
// to re-specify both TInput and TEmbedding, e.g. generator.GetService<string, Embedding<float>, TService>.
// The case of string/Embedding<float> is by far the most common case today, so this overload exists as an
Expand All @@ -45,13 +107,31 @@ public static class EmbeddingGeneratorExtensions
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService? GetService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
GetService<string, Embedding<float>, TService>(generator, serviceKey);

/// <summary>
/// Asks the <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/> for an object of type <typeparamref name="TService"/>
/// and throws an exception if one isn't available.
/// </summary>
/// <typeparam name="TService">The type of the object to be retrieved.</typeparam>
/// <param name="generator">The generator.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">No service of the requested type for the specified key is available.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that may be provided by the
/// <see cref="IEmbeddingGenerator{TInput,TEmbedding}"/>, including itself or any services it might be wrapping.
/// </remarks>
public static TService GetRequiredService<TService>(this IEmbeddingGenerator<string, Embedding<float>> generator, object? serviceKey = null) =>
GetRequiredService<string, Embedding<float>, TService>(generator, serviceKey);

/// <summary>Generates an embedding vector from the specified <paramref name="value"/>.</summary>
/// <typeparam name="TInput">The type from which embeddings will be generated.</typeparam>
/// <typeparam name="TEmbeddingElement">The numeric type of the embedding data.</typeparam>
Expand All @@ -60,6 +140,9 @@ public static class EmbeddingGeneratorExtensions
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The generated embedding for the specified <paramref name="value"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">The generator did not produce exactly one embedding.</exception>
/// <remarks>
/// This operation is equivalent to using <see cref="GenerateEmbeddingAsync"/> and returning the
/// resulting <see cref="Embedding{T}"/>'s <see cref="Embedding{T}.Vector"/> property.
Expand All @@ -84,6 +167,9 @@ public static async Task<ReadOnlyMemory<TEmbeddingElement>> GenerateEmbeddingVec
/// <returns>
/// The generated embedding for the specified <paramref name="value"/>.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="value"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">The generator did not produce exactly one embedding.</exception>
/// <remarks>
/// This operations is equivalent to using <see cref="IEmbeddingGenerator{TInput, TEmbedding}.GenerateAsync"/> with a
/// collection composed of the single <paramref name="value"/> and then returning the first embedding element from the
Expand Down Expand Up @@ -125,6 +211,9 @@ public static async Task<TEmbedding> GenerateEmbeddingAsync<TInput, TEmbedding>(
/// <param name="options">The embedding generation options to configure the request.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>An array containing tuples of the input values and the associated generated embeddings.</returns>
/// <exception cref="ArgumentNullException"><paramref name="generator"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="values"/> is <see langword="null"/>.</exception>
/// <exception cref="InvalidOperationException">The generator did not produce one embedding for each input value.</exception>
public static async Task<(TInput Value, TEmbedding Embedding)[]> GenerateAndZipAsync<TInput, TEmbedding>(
this IEmbeddingGenerator<TInput, TEmbedding> generator,
IEnumerable<TInput> values,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
/// <param name="serviceType">The type of object being requested.</param>
/// <param name="serviceKey">An optional key that can be used to help identify the target service.</param>
/// <returns>The found object, otherwise <see langword="null"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="serviceType"/> is <see langword="null"/>.</exception>
/// <remarks>
/// The purpose of this method is to allow for the retrieval of strongly typed services that might be provided by the
/// <see cref="IEmbeddingGenerator{TInput, TEmbedding}"/>, including itself or any services it might be wrapping.
Expand Down
15 changes: 15 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI.Abstractions/Throw.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;

namespace Microsoft.Shared.Diagnostics;

internal static partial class Throw
{
/// <summary>Throws an exception indicating that a required service is not available.</summary>
public static InvalidOperationException CreateMissingServiceException(Type serviceType, object? serviceKey) =>
new InvalidOperationException(serviceKey is null ?
$"No service of type '{serviceType}' is available." :
$"No service of type '{serviceType}' for the key '{serviceKey}' is available.");
}
Loading