Skip to content

Commit a0e6496

Browse files
authored
[MEDI] Introduce set of built-in Enrichers (#6957)
1 parent 8e4a3e2 commit a0e6496

File tree

13 files changed

+1100
-3
lines changed

13 files changed

+1100
-3
lines changed

src/Libraries/Microsoft.Extensions.DataIngestion/Microsoft.Extensions.DataIngestion.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
</ItemGroup>
1919

2020
<ItemGroup>
21+
<PackageReference Include="System.Collections.Immutable" Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'" />
2122
<PackageReference Include="Microsoft.Extensions.VectorData.Abstractions" />
2223
</ItemGroup>
2324

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Frozen;
6+
using System.Collections.Generic;
7+
using System.Runtime.CompilerServices;
8+
using System.Text;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
using Microsoft.Extensions.AI;
12+
using Microsoft.Shared.Diagnostics;
13+
14+
namespace Microsoft.Extensions.DataIngestion;
15+
16+
/// <summary>
17+
/// Enriches document chunks with a classification label based on their content.
18+
/// </summary>
19+
/// <remarks>This class uses a chat-based language model to analyze the content of document chunks and assign a
20+
/// single, most relevant classification label. The classification is performed using a predefined set of classes, with
21+
/// an optional fallback class for cases where no suitable classification can be determined.</remarks>
22+
public sealed class ClassificationEnricher : IngestionChunkProcessor<string>
23+
{
24+
private readonly IChatClient _chatClient;
25+
private readonly ChatOptions? _chatOptions;
26+
private readonly FrozenSet<string> _predefinedClasses;
27+
private readonly ChatMessage _systemPrompt;
28+
29+
/// <summary>
30+
/// Initializes a new instance of the <see cref="ClassificationEnricher"/> class.
31+
/// </summary>
32+
/// <param name="chatClient">The chat client used for classification.</param>
33+
/// <param name="predefinedClasses">The set of predefined classification classes.</param>
34+
/// <param name="chatOptions">Options for the chat client.</param>
35+
/// <param name="fallbackClass">The fallback class to use when no suitable classification is found. When not provided, it defaults to "Unknown".</param>
36+
public ClassificationEnricher(IChatClient chatClient, ReadOnlySpan<string> predefinedClasses,
37+
ChatOptions? chatOptions = null, string? fallbackClass = null)
38+
{
39+
_chatClient = Throw.IfNull(chatClient);
40+
_chatOptions = chatOptions;
41+
if (string.IsNullOrWhiteSpace(fallbackClass))
42+
{
43+
fallbackClass = "Unknown";
44+
}
45+
46+
_predefinedClasses = CreatePredefinedSet(predefinedClasses, fallbackClass!);
47+
_systemPrompt = CreateSystemPrompt(predefinedClasses, fallbackClass!);
48+
}
49+
50+
/// <summary>
51+
/// Gets the metadata key used to store the classification.
52+
/// </summary>
53+
public static string MetadataKey => "classification";
54+
55+
/// <inheritdoc />
56+
public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IAsyncEnumerable<IngestionChunk<string>> chunks,
57+
[EnumeratorCancellation] CancellationToken cancellationToken = default)
58+
{
59+
_ = Throw.IfNull(chunks);
60+
61+
await foreach (IngestionChunk<string> chunk in chunks.WithCancellation(cancellationToken))
62+
{
63+
var response = await _chatClient.GetResponseAsync(
64+
[
65+
_systemPrompt,
66+
new(ChatRole.User, chunk.Content)
67+
], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
68+
69+
chunk.Metadata[MetadataKey] = _predefinedClasses.Contains(response.Text)
70+
? response.Text
71+
: throw new InvalidOperationException($"Classification returned an unexpected class: '{response.Text}'.");
72+
73+
yield return chunk;
74+
}
75+
}
76+
77+
private static FrozenSet<string> CreatePredefinedSet(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
78+
{
79+
if (predefinedClasses.Length == 0)
80+
{
81+
Throw.ArgumentException(nameof(predefinedClasses), "Predefined classes must be provided.");
82+
}
83+
84+
HashSet<string> predefinedClassesSet = new(StringComparer.Ordinal) { fallbackClass };
85+
foreach (string predefinedClass in predefinedClasses)
86+
{
87+
#if NET
88+
if (predefinedClass.Contains(',', StringComparison.Ordinal))
89+
#else
90+
if (predefinedClass.IndexOf(',') >= 0)
91+
#endif
92+
{
93+
Throw.ArgumentException(nameof(predefinedClasses), $"Predefined class '{predefinedClass}' must not contain ',' character.");
94+
}
95+
96+
if (!predefinedClassesSet.Add(predefinedClass))
97+
{
98+
if (predefinedClass.Equals(fallbackClass, StringComparison.Ordinal))
99+
{
100+
Throw.ArgumentException(nameof(predefinedClasses), $"Fallback class '{fallbackClass}' must not be one of the predefined classes.");
101+
}
102+
103+
Throw.ArgumentException(nameof(predefinedClasses), $"Duplicate class found: '{predefinedClass}'.");
104+
}
105+
}
106+
107+
return predefinedClassesSet.ToFrozenSet();
108+
}
109+
110+
private static ChatMessage CreateSystemPrompt(ReadOnlySpan<string> predefinedClasses, string fallbackClass)
111+
{
112+
StringBuilder sb = new("You are a classification expert. Analyze the given text and assign a single, most relevant class. Use only the following predefined classes: ");
113+
114+
#if NET9_0_OR_GREATER
115+
sb.AppendJoin(", ", predefinedClasses!);
116+
#else
117+
#pragma warning disable IDE0058 // Expression value is never used
118+
for (int i = 0; i < predefinedClasses.Length; i++)
119+
{
120+
sb.Append(predefinedClasses[i]);
121+
if (i < predefinedClasses.Length - 1)
122+
{
123+
sb.Append(", ");
124+
}
125+
}
126+
#endif
127+
sb.Append(" and return ").Append(fallbackClass).Append(" when unable to classify.");
128+
#pragma warning restore IDE0058 // Expression value is never used
129+
130+
return new(ChatRole.System, sb.ToString());
131+
}
132+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
using Microsoft.Extensions.AI;
8+
using Microsoft.Shared.Diagnostics;
9+
10+
namespace Microsoft.Extensions.DataIngestion;
11+
12+
/// <summary>
13+
/// Enriches <see cref="IngestionDocumentImage"/> elements with alternative text using an AI service,
14+
/// so the generated embeddings can include the image content information.
15+
/// </summary>
16+
public sealed class ImageAlternativeTextEnricher : IngestionDocumentProcessor
17+
{
18+
private readonly IChatClient _chatClient;
19+
private readonly ChatOptions? _chatOptions;
20+
private readonly ChatMessage _systemPrompt;
21+
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="ImageAlternativeTextEnricher"/> class.
24+
/// </summary>
25+
/// <param name="chatClient">The chat client used to get responses for generating alternative text.</param>
26+
/// <param name="chatOptions">Options for the chat client.</param>
27+
public ImageAlternativeTextEnricher(IChatClient chatClient, ChatOptions? chatOptions = null)
28+
{
29+
_chatClient = Throw.IfNull(chatClient);
30+
_chatOptions = chatOptions;
31+
_systemPrompt = new(ChatRole.System, "Write a detailed alternative text for this image with less than 50 words.");
32+
}
33+
34+
/// <inheritdoc/>
35+
public override async Task<IngestionDocument> ProcessAsync(IngestionDocument document, CancellationToken cancellationToken = default)
36+
{
37+
_ = Throw.IfNull(document);
38+
39+
foreach (var element in document.EnumerateContent())
40+
{
41+
if (element is IngestionDocumentImage image)
42+
{
43+
await ProcessAsync(image, cancellationToken).ConfigureAwait(false);
44+
}
45+
else if (element is IngestionDocumentTable table)
46+
{
47+
foreach (var cell in table.Cells)
48+
{
49+
if (cell is IngestionDocumentImage cellImage)
50+
{
51+
await ProcessAsync(cellImage, cancellationToken).ConfigureAwait(false);
52+
}
53+
}
54+
}
55+
}
56+
57+
return document;
58+
}
59+
60+
private async Task ProcessAsync(IngestionDocumentImage image, CancellationToken cancellationToken)
61+
{
62+
if (image.Content.HasValue && !string.IsNullOrEmpty(image.MediaType)
63+
&& string.IsNullOrEmpty(image.AlternativeText))
64+
{
65+
var response = await _chatClient.GetResponseAsync(
66+
[
67+
_systemPrompt,
68+
new(ChatRole.User, [new DataContent(image.Content.Value, image.MediaType!)])
69+
], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
70+
71+
image.AlternativeText = response.Text;
72+
}
73+
}
74+
}
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Frozen;
6+
using System.Collections.Generic;
7+
using System.Runtime.CompilerServices;
8+
using System.Text;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
using Microsoft.Extensions.AI;
12+
using Microsoft.Shared.Diagnostics;
13+
14+
namespace Microsoft.Extensions.DataIngestion;
15+
16+
/// <summary>
17+
/// Enriches chunks with keyword extraction using an AI chat model.
18+
/// </summary>
19+
/// <remarks>
20+
/// It adds "keywords" metadata to each chunk. It's an array of strings representing the extracted keywords.
21+
/// </remarks>
22+
public sealed class KeywordEnricher : IngestionChunkProcessor<string>
23+
{
24+
private const int DefaultMaxKeywords = 5;
25+
#if NET
26+
private static readonly System.Buffers.SearchValues<char> _illegalCharacters = System.Buffers.SearchValues.Create([';', ',']);
27+
#else
28+
private static readonly char[] _illegalCharacters = [';', ','];
29+
#endif
30+
private readonly IChatClient _chatClient;
31+
private readonly ChatOptions? _chatOptions;
32+
private readonly FrozenSet<string>? _predefinedKeywords;
33+
private readonly ChatMessage _systemPrompt;
34+
35+
/// <summary>
36+
/// Initializes a new instance of the <see cref="KeywordEnricher"/> class.
37+
/// </summary>
38+
/// <param name="chatClient">The chat client used for keyword extraction.</param>
39+
/// <param name="predefinedKeywords">The set of predefined keywords for extraction.</param>
40+
/// <param name="chatOptions">Options for the chat client.</param>
41+
/// <param name="maxKeywords">The maximum number of keywords to extract. When not provided, it defaults to 5.</param>
42+
/// <param name="confidenceThreshold">The confidence threshold for keyword inclusion. When not provided, it defaults to 0.7.</param>
43+
/// <remarks>
44+
/// If no predefined keywords are provided, the model will extract keywords based on the content alone.
45+
/// Such results may vary more significantly between different AI models.
46+
/// </remarks>
47+
public KeywordEnricher(IChatClient chatClient, ReadOnlySpan<string> predefinedKeywords,
48+
ChatOptions? chatOptions = null, int? maxKeywords = null, double? confidenceThreshold = null)
49+
{
50+
_chatClient = Throw.IfNull(chatClient);
51+
_chatOptions = chatOptions;
52+
_predefinedKeywords = CreatePredfinedKeywords(predefinedKeywords);
53+
54+
double threshold = confidenceThreshold.HasValue
55+
? Throw.IfOutOfRange(confidenceThreshold.Value, 0.0, 1.0, nameof(confidenceThreshold))
56+
: 0.7;
57+
int keywordsCount = maxKeywords.HasValue
58+
? Throw.IfLessThanOrEqual(maxKeywords.Value, 0, nameof(maxKeywords))
59+
: DefaultMaxKeywords;
60+
_systemPrompt = CreateSystemPrompt(keywordsCount, predefinedKeywords, threshold);
61+
}
62+
63+
/// <summary>
64+
/// Gets the metadata key used to store the keywords.
65+
/// </summary>
66+
public static string MetadataKey => "keywords";
67+
68+
/// <inheritdoc/>
69+
public override async IAsyncEnumerable<IngestionChunk<string>> ProcessAsync(IAsyncEnumerable<IngestionChunk<string>> chunks,
70+
[EnumeratorCancellation] CancellationToken cancellationToken = default)
71+
{
72+
_ = Throw.IfNull(chunks);
73+
74+
await foreach (IngestionChunk<string> chunk in chunks.WithCancellation(cancellationToken))
75+
{
76+
// Structured response is not used here because it's not part of Microsoft.Extensions.AI.Abstractions.
77+
var response = await _chatClient.GetResponseAsync(
78+
[
79+
_systemPrompt,
80+
new(ChatRole.User, chunk.Content)
81+
], _chatOptions, cancellationToken: cancellationToken).ConfigureAwait(false);
82+
83+
#pragma warning disable EA0009 // Use 'System.MemoryExtensions.Split' for improved performance
84+
string[] keywords = response.Text.Split(';');
85+
if (_predefinedKeywords is not null)
86+
{
87+
foreach (var keyword in keywords)
88+
{
89+
if (!_predefinedKeywords.Contains(keyword))
90+
{
91+
throw new InvalidOperationException($"The extracted keyword '{keyword}' is not in the predefined keywords list.");
92+
}
93+
}
94+
}
95+
96+
chunk.Metadata[MetadataKey] = keywords;
97+
98+
yield return chunk;
99+
}
100+
}
101+
102+
private static FrozenSet<string>? CreatePredfinedKeywords(ReadOnlySpan<string> predefinedKeywords)
103+
{
104+
if (predefinedKeywords.Length == 0)
105+
{
106+
return null;
107+
}
108+
109+
HashSet<string> result = new(StringComparer.Ordinal);
110+
foreach (string keyword in predefinedKeywords)
111+
{
112+
#if NET
113+
if (keyword.AsSpan().ContainsAny(_illegalCharacters))
114+
#else
115+
if (keyword.IndexOfAny(_illegalCharacters) >= 0)
116+
#endif
117+
{
118+
Throw.ArgumentException(nameof(predefinedKeywords), $"Predefined keyword '{keyword}' contains an invalid character (';' or ',').");
119+
}
120+
121+
if (!result.Add(keyword))
122+
{
123+
Throw.ArgumentException(nameof(predefinedKeywords), $"Duplicate keyword found: '{keyword}'");
124+
}
125+
}
126+
127+
return result.ToFrozenSet(StringComparer.Ordinal);
128+
}
129+
130+
private static ChatMessage CreateSystemPrompt(int maxKeywords, ReadOnlySpan<string> predefinedKeywords, double confidenceThreshold)
131+
{
132+
StringBuilder sb = new($"You are a keyword extraction expert. Analyze the given text and extract up to {maxKeywords} most relevant keywords. ");
133+
134+
if (predefinedKeywords.Length > 0)
135+
{
136+
#pragma warning disable IDE0058 // Expression value is never used
137+
sb.Append("Focus on extracting keywords from the following predefined list: ");
138+
#if NET9_0_OR_GREATER
139+
sb.AppendJoin(", ", predefinedKeywords!);
140+
#else
141+
for (int i = 0; i < predefinedKeywords.Length; i++)
142+
{
143+
sb.Append(predefinedKeywords[i]);
144+
if (i < predefinedKeywords.Length - 1)
145+
{
146+
sb.Append(", ");
147+
}
148+
}
149+
#endif
150+
151+
sb.Append(". ");
152+
}
153+
154+
sb.Append("Exclude keywords with confidence score below ").Append(confidenceThreshold).Append('.');
155+
sb.Append(" Return just the keywords separated with ';'.");
156+
#pragma warning restore IDE0058 // Expression value is never used
157+
158+
return new(ChatRole.System, sb.ToString());
159+
}
160+
}

0 commit comments

Comments
 (0)