Skip to content

Commit 42587e4

Browse files
committed
Use async foreach to consume IAsyncEnumerable
- Flow CancellationToken to IAsyncEnumerable.GetAsyncEnumerator
1 parent b4a862a commit 42587e4

File tree

7 files changed

+236
-170
lines changed

7 files changed

+236
-170
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System.Collections.Generic;
5+
using System.Diagnostics;
6+
using System.Threading;
7+
using System.Threading.Channels;
8+
using System.Threading.Tasks;
9+
10+
namespace Microsoft.AspNetCore.SignalR.Internal
11+
{
12+
// True-internal because this is a weird and tricky class to use :)
13+
internal static class AsyncEnumerableAdapters
14+
{
15+
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
16+
{
17+
return new CancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
18+
}
19+
20+
public static IAsyncEnumerable<object> GetAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default)
21+
{
22+
return new ChannelAsyncEnumerable<T>(channel, cancellationToken);
23+
}
24+
25+
/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
26+
private class CancelableAsyncEnumerable<T> : IAsyncEnumerable<object>
27+
{
28+
private readonly IAsyncEnumerable<T> _asyncEnumerable;
29+
private readonly CancellationToken _cancellationToken;
30+
31+
public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken)
32+
{
33+
_asyncEnumerable = asyncEnumerable;
34+
_cancellationToken = cancellationToken;
35+
}
36+
37+
public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
38+
{
39+
Debug.Assert(cancellationToken == default);
40+
return new CancelableAsyncEnumerator(_asyncEnumerable.GetAsyncEnumerator(_cancellationToken));
41+
}
42+
43+
private class CancelableAsyncEnumerator : IAsyncEnumerator<object>
44+
{
45+
private IAsyncEnumerator<T> _asyncEnumerator;
46+
47+
public CancelableAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
48+
{
49+
_asyncEnumerator = asyncEnumerator;
50+
}
51+
52+
public object Current => _asyncEnumerator.Current;
53+
54+
public ValueTask<bool> MoveNextAsync()
55+
{
56+
return _asyncEnumerator.MoveNextAsync();
57+
}
58+
59+
public ValueTask DisposeAsync()
60+
{
61+
return _asyncEnumerator.DisposeAsync();
62+
}
63+
}
64+
}
65+
66+
/// <summary>Provides an IAsyncEnumerable of object for the data in a channel.</summary>
67+
private class ChannelAsyncEnumerable<T> : IAsyncEnumerable<object>
68+
{
69+
private readonly ChannelReader<T> _channel;
70+
private readonly CancellationToken _cancellationToken;
71+
72+
public ChannelAsyncEnumerable(ChannelReader<T> channel, CancellationToken cancellationToken)
73+
{
74+
_channel = channel;
75+
_cancellationToken = cancellationToken;
76+
}
77+
78+
public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
79+
{
80+
Debug.Assert(cancellationToken == default);
81+
return new ChannelAsyncEnumerator(_channel, _cancellationToken);
82+
}
83+
84+
private class ChannelAsyncEnumerator : IAsyncEnumerator<object>
85+
{
86+
/// <summary>The channel being enumerated.</summary>
87+
private readonly ChannelReader<T> _channel;
88+
/// <summary>Cancellation token used to cancel the enumeration.</summary>
89+
private readonly CancellationToken _cancellationToken;
90+
/// <summary>The current element of the enumeration.</summary>
91+
private T _current;
92+
93+
public ChannelAsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
94+
{
95+
_channel = channel;
96+
_cancellationToken = cancellationToken;
97+
}
98+
99+
public object Current => _current;
100+
101+
public ValueTask<bool> MoveNextAsync()
102+
{
103+
var result = _channel.ReadAsync(_cancellationToken);
104+
105+
if (result.IsCompletedSuccessfully)
106+
{
107+
_current = result.Result;
108+
return new ValueTask<bool>(true);
109+
}
110+
111+
return new ValueTask<bool>(MoveNextAsyncAwaited(result));
112+
}
113+
114+
private async Task<bool> MoveNextAsyncAwaited(ValueTask<T> channelReadTask)
115+
{
116+
try
117+
{
118+
_current = await channelReadTask;
119+
}
120+
catch (ChannelClosedException ex) when (ex.InnerException == null)
121+
{
122+
return false;
123+
}
124+
125+
return true;
126+
}
127+
128+
public ValueTask DisposeAsync()
129+
{
130+
return default;
131+
}
132+
}
133+
}
134+
}
135+
}

src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs

Lines changed: 0 additions & 100 deletions
This file was deleted.

src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,20 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
263263
{
264264
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
265265

266-
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
266+
if (result == null)
267267
{
268268
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
269269
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
270-
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
270+
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.");
271271
return;
272272
}
273273

274+
cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
275+
connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts);
276+
var enumerable = descriptor.FromReturnedStream(result, cts.Token);
277+
274278
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
275-
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts);
279+
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts);
276280
}
277281

278282
else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@@ -358,7 +362,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
358362
}
359363
}
360364

361-
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
365+
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable<object> enumerable, IServiceScope scope,
362366
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts)
363367
{
364368
string error = null;
@@ -367,10 +371,11 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext
367371
{
368372
try
369373
{
370-
while (await enumerator.MoveNextAsync())
374+
375+
await foreach (var streamItem in enumerable)
371376
{
372377
// Send the stream item
373-
await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
378+
await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem));
374379
}
375380
}
376381
catch (ChannelClosedException ex)
@@ -389,8 +394,6 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext
389394
}
390395
finally
391396
{
392-
await enumerator.DisposeAsync();
393-
394397
hubActivator.Release(hub);
395398

396399
// Dispose the linked CTS for the stream.
@@ -470,10 +473,10 @@ private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provi
470473
return authorizationResult.Succeeded;
471474
}
472475

473-
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
476+
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
474477
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
475478
{
476-
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
479+
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
477480
{
478481
// Non-null/empty InvocationId? Blocking
479482
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@@ -486,7 +489,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
486489
return false;
487490
}
488491

489-
if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
492+
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
490493
{
491494
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
492495
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
@@ -498,26 +501,6 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
498501
return true;
499502
}
500503

501-
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
502-
{
503-
if (result != null)
504-
{
505-
if (hubMethodDescriptor.IsStreamable)
506-
{
507-
if (streamCts == null)
508-
{
509-
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
510-
}
511-
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
512-
enumerator = hubMethodDescriptor.FromReturnedStream(result, streamCts.Token);
513-
return true;
514-
}
515-
}
516-
517-
enumerator = null;
518-
return false;
519-
}
520-
521504
private void DiscoverHubMethods()
522505
{
523506
var hubType = typeof(THub);

0 commit comments

Comments
 (0)