Skip to content

Removing CancelableAsyncEnumerable #32090

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 45 additions & 51 deletions src/SignalR/common/Shared/AsyncEnumerableAdapters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
Expand All @@ -13,38 +11,57 @@ namespace Microsoft.AspNetCore.SignalR.Internal
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumerableAdapters
{
public static IAsyncEnumerable<object?> MakeCancelableAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
public static IAsyncEnumerator<object?> MakeCancelableAsyncEnumerator<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
{
return new CancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken);
return enumerator as IAsyncEnumerator<object?> ?? new BoxedAsyncEnumerator<T>(enumerator);
}

public static IAsyncEnumerable<T> MakeCancelableTypedAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationTokenSource cts)
{
return new CancelableTypedAsyncEnumerable<T>(asyncEnumerable, cts);
}

#if NETCOREAPP
public static async IAsyncEnumerable<object?> MakeAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, [EnumeratorCancellation] CancellationToken cancellationToken = default)
public static IAsyncEnumerator<object?> MakeAsyncEnumeratorFromChannel<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default)
{
await foreach (var item in channel.ReadAllAsync(cancellationToken))
{
yield return item;
}
return new ChannelAsyncEnumerator<T>(channel, cancellationToken);
}
#else
// System.Threading.Channels.ReadAllAsync() is not available on netstandard2.0 and netstandard2.1
// But this is the exact same code that it uses
public static async IAsyncEnumerable<object?> MakeAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, [EnumeratorCancellation] CancellationToken cancellationToken = default)

private class ChannelAsyncEnumerator<T> : IAsyncEnumerator<object?>
{
while (await channel.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
private readonly ChannelReader<T> _channel;
private readonly CancellationToken _cancellationToken;
public ChannelAsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
{
_channel = channel;
_cancellationToken = cancellationToken;
}

public object? Current { get; private set; }

public ValueTask<bool> MoveNextAsync()
{
while (channel.TryRead(out var item))
if (_channel.TryRead(out var item))
{
yield return item;
Current = item;
return new ValueTask<bool>(true);
}

return new ValueTask<bool>(MoveNextAsyncAwaited());
}

private async Task<bool> MoveNextAsyncAwaited()
{
if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item))
{
Current = item;
return true;
}
return false;
}

public ValueTask DisposeAsync() => default;
}
#endif

private class CancelableTypedAsyncEnumerable<TResult> : IAsyncEnumerable<TResult>
{
Expand Down Expand Up @@ -99,48 +116,25 @@ public ValueTask DisposeAsync()
}
}

/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
private class CancelableAsyncEnumerable<T> : IAsyncEnumerable<object?>
private class BoxedAsyncEnumerator<T> : IAsyncEnumerator<object?>
{
private readonly IAsyncEnumerable<T> _asyncEnumerable;
private readonly CancellationToken _cancellationToken;
private IAsyncEnumerator<T> _asyncEnumerator;

public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken)
public BoxedAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
{
_asyncEnumerable = asyncEnumerable;
_cancellationToken = cancellationToken;
_asyncEnumerator = asyncEnumerator;
}

public IAsyncEnumerator<object?> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
// Assume that this will be iterated through with await foreach which always passes a default token.
// Instead use the token from the ctor.
Debug.Assert(cancellationToken == default);
public object? Current => _asyncEnumerator.Current;

var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken);
return enumeratorOfT as IAsyncEnumerator<object?> ?? new BoxedAsyncEnumerator(enumeratorOfT);
public ValueTask<bool> MoveNextAsync()
{
return _asyncEnumerator.MoveNextAsync();
}

private class BoxedAsyncEnumerator : IAsyncEnumerator<object?>
public ValueTask DisposeAsync()
{
private IAsyncEnumerator<T> _asyncEnumerator;

public BoxedAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
{
_asyncEnumerator = asyncEnumerator;
}

public object? Current => _asyncEnumerator.Current;

public ValueTask<bool> MoveNextAsync()
{
return _asyncEnumerator.MoveNextAsync();
}

public ValueTask DisposeAsync()
{
return _asyncEnumerator.DisposeAsync();
}
return _asyncEnumerator.DisposeAsync();
}
}
}
Expand Down
21 changes: 17 additions & 4 deletions src/SignalR/common/testassets/Tests.Utils/TestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,21 @@ public Task<IList<HubMessage>> StreamAsync(string methodName, params object[] ar
public async Task<IList<HubMessage>> StreamAsync(string methodName, string[] streamIds, params object[] args)
{
var invocationId = await SendStreamInvocationAsync(methodName, streamIds, args);
return await ListenAllAsync(invocationId);
}

var messages = new List<HubMessage>();
public async Task<IList<HubMessage>> ListenAllAsync(string invocationId)
{
var result = new List<HubMessage>();
await foreach(var item in ListenAsync(invocationId))
{
result.Add(item);
}
return result;
}

public async IAsyncEnumerable<HubMessage> ListenAsync(string invocationId)
{
while (true)
{
var message = await ReadAsync();
Expand All @@ -120,11 +133,11 @@ public async Task<IList<HubMessage>> StreamAsync(string methodName, string[] str
switch (message)
{
case StreamItemMessage _:
messages.Add(message);
yield return message;
break;
case CompletionMessage _:
messages.Add(message);
return messages;
yield return message;
yield break;
default:
// Message implement ToString so this should be helpful.
throw new NotSupportedException($"TestClient recieved an unexpected message: {message}.");
Expand Down
11 changes: 5 additions & 6 deletions src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex
public override async Task OnConnectedAsync(HubConnectionContext connection)
{
var scope = _serviceScopeFactory.CreateScope();
connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);

try
{
Expand Down Expand Up @@ -470,14 +470,13 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect
return;
}

var enumerable = descriptor.FromReturnedStream(result, streamCts.Token);

await using var enumerator = descriptor.FromReturnedStream(result, streamCts.Token);
Log.StreamingResult(_logger, invocationId, descriptor.MethodExecutor);

var streamItemMessage = new StreamItemMessage(invocationId, null);
await foreach (var streamItem in enumerable)

while (await enumerator.MoveNextAsync())
{
streamItemMessage.Item = streamItem;
streamItemMessage.Item = enumerator.Current;
// Send the stream item
await connection.WriteAsync(streamItemMessage);
}
Expand Down
32 changes: 16 additions & 16 deletions src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class HubMethodDescriptor
{
private static readonly MethodInfo MakeCancelableAsyncEnumerableMethod = typeof(AsyncEnumerableAdapters)
private static readonly MethodInfo MakeCancelableAsyncEnumeratorMethod = typeof(AsyncEnumerableAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod);
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator)) && m.IsGenericMethod);

private static readonly MethodInfo MakeAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters)
private static readonly MethodInfo MakeAsyncEnumeratorFromChannelMethod = typeof(AsyncEnumerableAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeAsyncEnumerableFromChannel)) && m.IsGenericMethod);
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeAsyncEnumeratorFromChannel)) && m.IsGenericMethod);

private readonly MethodInfo? _makeCancelableEnumerableMethodInfo;
private Func<object, CancellationToken, IAsyncEnumerable<object>>? _makeCancelableEnumerable;
private readonly MethodInfo? _makeCancelableEnumeratorMethodInfo;
private Func<object, CancellationToken, IAsyncEnumerator<object>>? _makeCancelableEnumerator;

public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
Expand All @@ -46,14 +46,14 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut
if (openReturnType == typeof(IAsyncEnumerable<>))
{
StreamReturnType = returnType.GetGenericArguments()[0];
_makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod;
_makeCancelableEnumeratorMethodInfo = MakeCancelableAsyncEnumeratorMethod;
break;
}

if (openReturnType == typeof(ChannelReader<>))
{
StreamReturnType = returnType.GetGenericArguments()[0];
_makeCancelableEnumerableMethodInfo = MakeAsyncEnumerableFromChannelMethod;
_makeCancelableEnumeratorMethodInfo = MakeAsyncEnumeratorFromChannelMethod;
break;
}
}
Expand Down Expand Up @@ -107,22 +107,22 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAut

public bool HasSyntheticArguments { get; private set; }

public IAsyncEnumerable<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
public IAsyncEnumerator<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
if (_makeCancelableEnumerable == null)
if (_makeCancelableEnumerator == null)
{
_makeCancelableEnumerable = CompileConvertToEnumerable(_makeCancelableEnumerableMethodInfo!, StreamReturnType!);
_makeCancelableEnumerator = CompileConvertToEnumerator(_makeCancelableEnumeratorMethodInfo!, StreamReturnType!);
}

return _makeCancelableEnumerable.Invoke(stream, cancellationToken);
return _makeCancelableEnumerator.Invoke(stream, cancellationToken);
}

private static Func<object, CancellationToken, IAsyncEnumerable<object>> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType)
private static Func<object, CancellationToken, IAsyncEnumerator<object>> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType)
{
// This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable<object>:
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel<T>(channelReader, cancellationToken);
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator<T>(asyncEnumerable, cancellationToken);
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumeratorFromChannel<T>(channelReader, cancellationToken);

var parameters = new[]
{
Expand All @@ -139,7 +139,7 @@ private static Func<object, CancellationToken, IAsyncEnumerable<object>> Compile
};

var methodCall = Expression.Call(null, genericMethodInfo, methodArguments);
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerable<object>>>(methodCall, parameters);
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerator<object>>>(methodCall, parameters);
return lambda.Compile();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public Task SendToCaller(string message)

public Task ProtocolError()
{
return Clients.Caller.SendAsync("Send", new SelfRef());
return Clients.Caller.SendAsync("Send", new SelfRef());
}

public void InvalidArgument(CancellationToken token)
Expand Down Expand Up @@ -1027,6 +1027,39 @@ public async IAsyncEnumerable<int> CancelableStreamGeneratedAsyncEnumerable([Enu
yield break;
}

public async IAsyncEnumerable<int> CountingCancelableStreamGeneratedAsyncEnumerable(int count, [EnumeratorCancellation] CancellationToken token)
{
for (int i = 0; i < count; i++)
{
await Task.Yield();
yield return i;
}
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
_tcsService.EndMethod.SetResult(null);
yield break;
}

public ChannelReader<int> CountingCancelableStreamGeneratedChannel(int count, CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);

Task.Run(async () =>
{
for (int i = 0; i < count; i++)
{
await Task.Yield();
await channel.Writer.WriteAsync(i);
}
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});

return channel.Reader;
}

public IAsyncEnumerable<int> CancelableStreamCustomAsyncEnumerable()
{
return new CustomAsyncEnumerable(_tcsService);
Expand Down
Loading