From eae2eb2e58e1a5ba9f5137ff97f08153c541f922 Mon Sep 17 00:00:00 2001 From: Laszlo Deak Date: Wed, 5 May 2021 21:30:24 +0200 Subject: [PATCH 1/2] Convert MakeEnumerable and MakeEnumerableChannel to making IAsyncEnumerator DefaultHubDispatcher iterates using the enumerator Adding a tests --- .../common/Shared/AsyncEnumerableAdapters.cs | 96 +++++++++---------- .../testassets/Tests.Utils/TestClient.cs | 21 +++- .../Core/src/Internal/DefaultHubDispatcher.cs | 11 +-- .../Core/src/Internal/HubMethodDescriptor.cs | 32 +++---- .../HubConnectionHandlerTestUtils/Hubs.cs | 35 ++++++- .../SignalR/test/HubConnectionHandlerTests.cs | 80 ++++++++++++++++ 6 files changed, 197 insertions(+), 78 deletions(-) diff --git a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs index 5a641510dd41..ca3f8bcacc6a 100644 --- a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs +++ b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs @@ -2,8 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; -using System.Diagnostics; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -13,9 +11,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal // True-internal because this is a weird and tricky class to use :) internal static class AsyncEnumerableAdapters { - public static IAsyncEnumerable MakeCancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default) + public static IAsyncEnumerator MakeCancelableAsyncEnumerator(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default) { - return new CancelableAsyncEnumerable(asyncEnumerable, cancellationToken); + var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken); + return enumerator as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumerator); } public static IAsyncEnumerable MakeCancelableTypedAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationTokenSource cts) @@ -23,28 +22,46 @@ public static IAsyncEnumerable MakeCancelableTypedAsyncEnumerable(IAsyncEn return new CancelableTypedAsyncEnumerable(asyncEnumerable, cts); } -#if NETCOREAPP - public static async IAsyncEnumerable MakeAsyncEnumerableFromChannel(ChannelReader channel, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public static IAsyncEnumerator MakeAsyncEnumeratorFromChannel(ChannelReader channel, CancellationToken cancellationToken = default) { - await foreach (var item in channel.ReadAllAsync(cancellationToken)) - { - yield return item; - } + return new ChannelAsyncEnumerator(channel, cancellationToken); } -#else - // System.Threading.Channels.ReadAllAsync() is not available on netstandard2.0 and netstandard2.1 - // But this is the exact same code that it uses - public static async IAsyncEnumerable MakeAsyncEnumerableFromChannel(ChannelReader channel, [EnumeratorCancellation] CancellationToken cancellationToken = default) + + private class ChannelAsyncEnumerator : IAsyncEnumerator { - while (await channel.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + private readonly ChannelReader _channel; + private readonly CancellationToken _cancellationToken; + public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) + { + _channel = channel; + _cancellationToken = cancellationToken; + } + + public object? Current { get; private set; } + + public ValueTask MoveNextAsync() { - while (channel.TryRead(out var item)) + if (_channel.TryRead(out var item)) { - yield return item; + Current = item; + return new ValueTask(true); } + + return new ValueTask(MoveNextAsyncAwaited()); } + + private async Task MoveNextAsyncAwaited() + { + if (await _channel.WaitToReadAsync(_cancellationToken) && _channel.TryRead(out var item)) + { + Current = item; + return true; + } + return false; + } + + public ValueTask DisposeAsync() => default; } -#endif private class CancelableTypedAsyncEnumerable : IAsyncEnumerable { @@ -99,48 +116,25 @@ public ValueTask DisposeAsync() } } - /// Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object. - private class CancelableAsyncEnumerable : IAsyncEnumerable + private class BoxedAsyncEnumerator : IAsyncEnumerator { - private readonly IAsyncEnumerable _asyncEnumerable; - private readonly CancellationToken _cancellationToken; + private IAsyncEnumerator _asyncEnumerator; - public CancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken) + public BoxedAsyncEnumerator(IAsyncEnumerator asyncEnumerator) { - _asyncEnumerable = asyncEnumerable; - _cancellationToken = cancellationToken; + _asyncEnumerator = asyncEnumerator; } - public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - { - // Assume that this will be iterated through with await foreach which always passes a default token. - // Instead use the token from the ctor. - Debug.Assert(cancellationToken == default); + public object? Current => _asyncEnumerator.Current; - var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken); - return enumeratorOfT as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumeratorOfT); + public ValueTask MoveNextAsync() + { + return _asyncEnumerator.MoveNextAsync(); } - private class BoxedAsyncEnumerator : IAsyncEnumerator + public ValueTask DisposeAsync() { - private IAsyncEnumerator _asyncEnumerator; - - public BoxedAsyncEnumerator(IAsyncEnumerator asyncEnumerator) - { - _asyncEnumerator = asyncEnumerator; - } - - public object? Current => _asyncEnumerator.Current; - - public ValueTask MoveNextAsync() - { - return _asyncEnumerator.MoveNextAsync(); - } - - public ValueTask DisposeAsync() - { - return _asyncEnumerator.DisposeAsync(); - } + return _asyncEnumerator.DisposeAsync(); } } } diff --git a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs index d6ec7f756b32..2f0c2926637d 100644 --- a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs +++ b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs @@ -101,8 +101,21 @@ public Task> StreamAsync(string methodName, params object[] ar public async Task> StreamAsync(string methodName, string[] streamIds, params object[] args) { var invocationId = await SendStreamInvocationAsync(methodName, streamIds, args); + return await ListenAllAsync(invocationId); + } - var messages = new List(); + public async Task> ListenAllAsync(string invocationId) + { + var result = new List(); + await foreach(var item in ListenAsync(invocationId)) + { + result.Add(item); + } + return result; + } + + public async IAsyncEnumerable ListenAsync(string invocationId) + { while (true) { var message = await ReadAsync(); @@ -120,11 +133,11 @@ public async Task> StreamAsync(string methodName, string[] str switch (message) { case StreamItemMessage _: - messages.Add(message); + yield return message; break; case CompletionMessage _: - messages.Add(message); - return messages; + yield return message; + yield break; default: // Message implement ToString so this should be helpful. throw new NotSupportedException($"TestClient recieved an unexpected message: {message}."); diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 243b3a5d8622..cfa4dada3b90 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -74,7 +74,7 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { var scope = _serviceScopeFactory.CreateScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); + connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); try { @@ -470,14 +470,13 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect return; } - var enumerable = descriptor.FromReturnedStream(result, streamCts.Token); - + await using var enumerator = descriptor.FromReturnedStream(result, streamCts.Token); Log.StreamingResult(_logger, invocationId, descriptor.MethodExecutor); - var streamItemMessage = new StreamItemMessage(invocationId, null); - await foreach (var streamItem in enumerable) + + while (await enumerator.MoveNextAsync()) { - streamItemMessage.Item = streamItem; + streamItemMessage.Item = enumerator.Current; // Send the stream item await connection.WriteAsync(streamItemMessage); } diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index 0f34e94b5b7a..82df14b20469 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -15,16 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal { internal class HubMethodDescriptor { - private static readonly MethodInfo MakeCancelableAsyncEnumerableMethod = typeof(AsyncEnumerableAdapters) + private static readonly MethodInfo MakeCancelableAsyncEnumeratorMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator)) && m.IsGenericMethod); - private static readonly MethodInfo MakeAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters) + private static readonly MethodInfo MakeAsyncEnumeratorFromChannelMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeAsyncEnumerableFromChannel)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeAsyncEnumeratorFromChannel)) && m.IsGenericMethod); - private readonly MethodInfo? _makeCancelableEnumerableMethodInfo; - private Func>? _makeCancelableEnumerable; + private readonly MethodInfo? _makeCancelableEnumeratorMethodInfo; + private Func>? _makeCancelableEnumerator; public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { @@ -46,14 +46,14 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable)) { StreamReturnType = returnType.GetGenericArguments()[0]; - _makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod; + _makeCancelableEnumeratorMethodInfo = MakeCancelableAsyncEnumeratorMethod; break; } if (openReturnType == typeof(ChannelReader<>)) { StreamReturnType = returnType.GetGenericArguments()[0]; - _makeCancelableEnumerableMethodInfo = MakeAsyncEnumerableFromChannelMethod; + _makeCancelableEnumeratorMethodInfo = MakeAsyncEnumeratorFromChannelMethod; break; } } @@ -107,22 +107,22 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable FromReturnedStream(object stream, CancellationToken cancellationToken) + public IAsyncEnumerator FromReturnedStream(object stream, CancellationToken cancellationToken) { // there is the potential for compile to be called times but this has no harmful effect other than perf - if (_makeCancelableEnumerable == null) + if (_makeCancelableEnumerator == null) { - _makeCancelableEnumerable = CompileConvertToEnumerable(_makeCancelableEnumerableMethodInfo!, StreamReturnType!); + _makeCancelableEnumerator = CompileConvertToEnumerator(_makeCancelableEnumeratorMethodInfo!, StreamReturnType!); } - return _makeCancelableEnumerable.Invoke(stream, cancellationToken); + return _makeCancelableEnumerator.Invoke(stream, cancellationToken); } - private static Func> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType) + private static Func> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType) { // This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable: - // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable(asyncEnumerable, cancellationToken); - // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel(channelReader, cancellationToken); + // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerator(asyncEnumerable, cancellationToken); + // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumeratorFromChannel(channelReader, cancellationToken); var parameters = new[] { @@ -139,7 +139,7 @@ private static Func> Compile }; var methodCall = Expression.Call(null, genericMethodInfo, methodArguments); - var lambda = Expression.Lambda>>(methodCall, parameters); + var lambda = Expression.Lambda>>(methodCall, parameters); return lambda.Compile(); } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 2c7b76a105a9..089e80a37c61 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -183,7 +183,7 @@ public Task SendToCaller(string message) public Task ProtocolError() { - return Clients.Caller.SendAsync("Send", new SelfRef()); + return Clients.Caller.SendAsync("Send", new SelfRef()); } public void InvalidArgument(CancellationToken token) @@ -1027,6 +1027,39 @@ public async IAsyncEnumerable CancelableStreamGeneratedAsyncEnumerable([Enu yield break; } + public async IAsyncEnumerable CountingCancelableStreamGeneratedAsyncEnumerable(int count, [EnumeratorCancellation] CancellationToken token) + { + for (int i = 0; i < count; i++) + { + await Task.Yield(); + yield return i; + } + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + _tcsService.EndMethod.SetResult(null); + yield break; + } + + public ChannelReader CountingCancelableStreamGeneratedChannel(int count, CancellationToken token) + { + var channel = Channel.CreateBounded(10); + + Task.Run(async () => + { + for (int i = 0; i < count; i++) + { + await Task.Yield(); + await channel.Writer.WriteAsync(i); + } + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + channel.Writer.TryComplete(); + _tcsService.EndMethod.SetResult(null); + }); + + return channel.Reader; + } + public IAsyncEnumerable CancelableStreamCustomAsyncEnumerable() { return new CustomAsyncEnumerable(_tcsService); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 314da2851a08..31cfe76e801c 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -4147,6 +4147,53 @@ public async Task StreamHubMethodCanBeTriggeredOnCancellation(string methodName, } } + [Theory] + [InlineData(nameof(LongRunningHub.CountingCancelableStreamGeneratedAsyncEnumerable), 2)] + [InlineData(nameof(LongRunningHub.CountingCancelableStreamGeneratedChannel), 2)] + public async Task CancellationAfterGivenMessagesEndsStreaming(string methodName, int count) + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + var invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); + + using (var client = new TestClient(invocationBinder: invocationBinder.Object)) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + // Start streaming count number of messages. + var invocationId = await client.SendStreamInvocationAsync(methodName, count).DefaultTimeout(); + + // Listening on incoming messages + var listeningMessages = client.ListenAsync(invocationId); + + // Wait for the number of messages expected to be received. This point the sender just waits forever or until cancellation. + await listeningMessages.ReadAsync(count).DefaultTimeout(); + + // Send cancellation. + await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).DefaultTimeout(); + + // Wait for the completion message. + var messages = await listeningMessages.ReadAllAsync().DefaultTimeout(); + Assert.Single(messages); + + // CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled. + await tcsService.EndMethod.Task.DefaultTimeout(); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.DefaultTimeout(); + } + } + } + [Fact] public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted() { @@ -4522,4 +4569,37 @@ private class HttpContextFeatureImpl : IHttpContextFeature public HttpContext HttpContext { get; set; } } } + + public static class IAsyncEnumerableExtension + { + public static async Task> ReadAsync(this IAsyncEnumerable enumerable, int count) + { + if (count <= 0) + { + throw new ArgumentException("Input must be greater than zero.", nameof(count)); + } + + var result = new List(); + await foreach (var item in enumerable) + { + result.Add(item); + if (result.Count == count) + { + break; + } + } + return result; + } + + public static async Task> ReadAllAsync(this IAsyncEnumerable enumerable) + { + var result = new List(); + await foreach (var item in enumerable) + { + result.Add(item); + } + + return result; + } + } } From c0c8ae57fa861272fc3c0d140b63718cea6d73e2 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 12 May 2021 16:34:55 -0700 Subject: [PATCH 2/2] Update src/SignalR/common/Shared/AsyncEnumerableAdapters.cs --- src/SignalR/common/Shared/AsyncEnumerableAdapters.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs index ca3f8bcacc6a..a94f23296a07 100644 --- a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs +++ b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs @@ -52,7 +52,7 @@ public ValueTask MoveNextAsync() private async Task MoveNextAsyncAwaited() { - if (await _channel.WaitToReadAsync(_cancellationToken) && _channel.TryRead(out var item)) + if (await _channel.WaitToReadAsync(_cancellationToken).ConfigureAwait(false) && _channel.TryRead(out var item)) { Current = item; return true;