From 4de77e9c2c5f4cc9a18b1efcf4bf7906973861b1 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 16 Jan 2019 16:37:52 -0800 Subject: [PATCH 1/7] Support IAsyncEnumerable returns in SignalR hubs --- .../samples/SignalRSamples/Hubs/Streaming.cs | 10 ++ .../SignalRSamples/SignalRSamples.csproj | 2 +- .../SignalRSamples/wwwroot/streaming.html | 9 +- .../src/Internal/AsyncEnumeratorAdapters.cs | 96 ++++++++++++------- .../Core/src/Internal/DefaultHubDispatcher.cs | 6 +- .../Core/src/Internal/HubMethodDescriptor.cs | 94 +++++++++--------- 6 files changed, 130 insertions(+), 87 deletions(-) diff --git a/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs b/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs index ee5401b7c1e0..f0d1dc4baaa6 100644 --- a/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs +++ b/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Reactive.Linq; using System.Threading.Channels; using System.Threading.Tasks; @@ -11,6 +12,15 @@ namespace SignalRSamples.Hubs { public class Streaming : Hub { + public async IAsyncEnumerable AsyncEnumerableCounter(int count, int delay) + { + for (var i = 0; i < count; i++) + { + yield return i; + await Task.Delay(delay); + } + } + public ChannelReader ObservableCounter(int count, int delay) { var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay)) diff --git a/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj b/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj index 25201e0f9aa0..283fe8d1bb39 100644 --- a/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj +++ b/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj @@ -1,4 +1,4 @@ - + netcoreapp3.0 diff --git a/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html b/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html index 44becb3cb29b..5cc0b3fd6745 100644 --- a/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html +++ b/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html @@ -17,6 +17,7 @@

Controls

+
@@ -32,7 +33,7 @@

Results

let resultsList = document.getElementById('resultsList'); let channelButton = document.getElementById('channelButton'); let observableButton = document.getElementById('observableButton'); - let clearButton = document.getElementById('clearButton'); + let asyncEnumerableButton = document.getElementById('asyncEnumerableButton'); let connectButton = document.getElementById('connectButton'); let disconnectButton = document.getElementById('disconnectButton'); @@ -61,6 +62,7 @@

Results

connection.onclose(function () { channelButton.disabled = true; observableButton.disabled = true; + asyncEnumerableButton.disabled = true; connectButton.disabled = false; disconnectButton.disabled = true; @@ -71,12 +73,17 @@

Results

.then(function () { channelButton.disabled = false; observableButton.disabled = false; + asyncEnumerableButton.disabled = false; connectButton.disabled = true; disconnectButton.disabled = false; addLine('resultsList', 'connected', 'green'); }); }); + click('asyncEnumerableButton', function () { + run('AsyncEnumerableCounter'); + }) + click('observableButton', function () { run('ObservableCounter'); }); diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs index 0e50b5c4338e..1dcc06d6a2b2 100644 --- a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs +++ b/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -11,74 +12,95 @@ namespace Microsoft.AspNetCore.SignalR.Internal // True-internal because this is a weird and tricky class to use :) internal static class AsyncEnumeratorAdapters { - public static IAsyncEnumerator GetAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) + public static IAsyncEnumerator GetAsyncEnumeratorFromAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default(CancellationToken)) { - // Nothing to dispose when we finish enumerating in this case. - return new AsyncEnumerator(channel, cancellationToken, disposable: null); + var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken); + + if (typeof(T).IsValueType) + { + return new BoxedAsyncEnumerator(enumerator); + } + + return (IAsyncEnumerator)enumerator; + } + + public static IAsyncEnumerator GetAsyncEnumeratorFromChannel(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) + { + return new ChannelAsyncEnumerator(channel, cancellationToken); + } + + /// Converts an IAsyncEnumerator of T to an IAsyncEnumerator of object. + private class BoxedAsyncEnumerator : IAsyncEnumerator + { + private IAsyncEnumerator _asyncEnumerator; + + public BoxedAsyncEnumerator(IAsyncEnumerator asyncEnumerator) + { + _asyncEnumerator = asyncEnumerator; + } + + public object Current => _asyncEnumerator.Current; + + public ValueTask MoveNextAsync() + { + return _asyncEnumerator.MoveNextAsync(); + } + + public ValueTask DisposeAsync() + { + return _asyncEnumerator.DisposeAsync(); + } } /// Provides an async enumerator for the data in a channel. - internal class AsyncEnumerator : IAsyncEnumerator, IDisposable + private class ChannelAsyncEnumerator : IAsyncEnumerator { /// The channel being enumerated. private readonly ChannelReader _channel; /// Cancellation token used to cancel the enumeration. private readonly CancellationToken _cancellationToken; /// The current element of the enumeration. - private object _current; + private T _current; - private readonly IDisposable _disposable; - - internal AsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken, IDisposable disposable) + public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) { _channel = channel; _cancellationToken = cancellationToken; - _disposable = disposable; } public object Current => _current; - public Task MoveNextAsync() + public ValueTask MoveNextAsync() { var result = _channel.ReadAsync(_cancellationToken); if (result.IsCompletedSuccessfully) { _current = result.Result; - return Task.FromResult(true); + return new ValueTask(true); } - return result.AsTask().ContinueWith((t, s) => + return new ValueTask(MoveNextAsyncAwaited(result)); + } + + private async Task MoveNextAsyncAwaited(ValueTask channelReadTask) + { + try { - var thisRef = (AsyncEnumerator)s; - if (t.IsFaulted && t.Exception.InnerException is ChannelClosedException cce && cce.InnerException == null) - { - return false; - } - thisRef._current = t.GetAwaiter().GetResult(); - return true; - }, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default); + _current = await channelReadTask; + } + catch (ChannelClosedException ex) when (ex.InnerException == null) + { + return false; + } + + return true; } - public void Dispose() + public ValueTask DisposeAsync() { - _disposable?.Dispose(); + return default; } } } - - /// Represents an enumerator accessed asynchronously. - /// Specifies the type of the data enumerated. - internal interface IAsyncEnumerator - { - /// Asynchronously move the enumerator to the next element. - /// - /// A task that returns true if the enumerator was successfully advanced to the next item, - /// or false if no more data was available in the collection. - /// - Task MoveNextAsync(); - - /// Gets the current element being enumerated. - T Current { get; } - } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 01f470b0c7e8..1a0f3f89ec5f 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -422,7 +422,7 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext } finally { - (enumerator as IDisposable)?.Dispose(); + await enumerator.DisposeAsync(); await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); @@ -534,14 +534,14 @@ private bool TryGetStreamingEnumerator(HubConnectionContext connection, string i { if (result != null) { - if (hubMethodDescriptor.IsChannel) + if (hubMethodDescriptor.IsStreamable) { if (streamCts == null) { streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); } connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); - enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token); + enumerator = hubMethodDescriptor.FromReturnedStream(result, streamCts.Token); return true; } } diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index dec2e67aafb9..1c487225feec 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -15,9 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal { internal class HubMethodDescriptor { - private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters) + private static readonly MethodInfo GetAsyncEnumeratorFromAsyncEnumerableMethod = typeof(AsyncEnumeratorAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumeratorFromAsyncEnumerable)) && m.IsGenericMethod); + + private static readonly MethodInfo GetAsyncEnumeratorFromChannelMethod = typeof(AsyncEnumeratorAdapters) + .GetRuntimeMethods() + .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumeratorFromChannel)) && m.IsGenericMethod); + + private MethodInfo _convertToEnumeratorMethodInfo; + private Func> _convertToEnumerator; public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { @@ -27,10 +34,28 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable)) + { + StreamReturnType = closedType.GetGenericArguments()[0]; + _convertToEnumeratorMethodInfo = GetAsyncEnumeratorFromAsyncEnumerableMethod; + break; + } + + if (openType == typeof(ChannelReader<>)) + { + StreamReturnType = closedType.GetGenericArguments()[0]; + _convertToEnumeratorMethodInfo = GetAsyncEnumeratorFromChannelMethod; + break; + } } // Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers @@ -66,8 +91,6 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable StreamingParameters { get; private set; } - private Func> _convertToEnumerator; - public ObjectMethodExecutor MethodExecutor { get; } public IReadOnlyList ParameterTypes { get; } @@ -76,9 +99,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable IsChannel; + public bool IsStreamable => StreamReturnType != null; public Type StreamReturnType { get; } @@ -86,57 +107,40 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>)); - if (channelType == null) - { - payloadType = null; - return false; - } - - payloadType = channelType.GetGenericArguments()[0]; - return true; - } - - public IAsyncEnumerator FromChannel(object channel, CancellationToken cancellationToken) + public IAsyncEnumerator FromReturnedStream(object stream, CancellationToken cancellationToken) { // there is the potential for compile to be called times but this has no harmful effect other than perf if (_convertToEnumerator == null) { - _convertToEnumerator = CompileConvertToEnumerator(GetAsyncEnumeratorMethod, StreamReturnType); + _convertToEnumerator = CompileConvertToEnumerator(_convertToEnumeratorMethodInfo, StreamReturnType); } - return _convertToEnumerator.Invoke(channel, cancellationToken); + return _convertToEnumerator.Invoke(stream, cancellationToken); } private static Func> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType) { - // This will call one of two adapter methods to wrap the passed in streamable value - // and cancellation token to an IAsyncEnumerator - // ChannelReader - // AsyncEnumeratorAdapters.GetAsyncEnumerator(channelReader, cancellationToken); + // This will call one of two adapter methods to wrap the passed in streamable value and cancellation token + // into an IAsyncEnumerator: + // - AsyncEnumeratorAdapters.GetAsyncEnumeratorFromAsyncEnumerable(asyncEnumerable, cancellationToken); + // - AsyncEnumeratorAdapters.GetAsyncEnumeratorFromChannel(channelReader, cancellationToken); - var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType); + var parameters = new[] + { + Expression.Parameter(typeof(object)), + Expression.Parameter(typeof(CancellationToken)), + }; + var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType); var methodParameters = genericMethodInfo.GetParameters(); - - // arg1 and arg2 are the parameter names on Func - // we reference these values and then use them to call adaptor method - var targetParameter = Expression.Parameter(typeof(object), "arg1"); - var parametersParameter = Expression.Parameter(typeof(CancellationToken), "arg2"); - - var parameters = new List + var methodArguements = new Expression[] { - Expression.Convert(targetParameter, methodParameters[0].ParameterType), - parametersParameter + Expression.Convert(parameters[0], methodParameters[0].ParameterType), + parameters[1], }; - var methodCall = Expression.Call(null, genericMethodInfo, parameters); - - var castMethodCall = Expression.Convert(methodCall, typeof(IAsyncEnumerator)); - - var lambda = Expression.Lambda>>(castMethodCall, targetParameter, parametersParameter); + var methodCall = Expression.Call(null, genericMethodInfo, methodArguements); + var lambda = Expression.Lambda>>(methodCall, parameters); return lambda.Compile(); } } From a42fff6bd2b66ed87817bff5eda63de519399c96 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 16 Jan 2019 20:15:31 -0800 Subject: [PATCH 2/7] Bah, Humbug! I was hoping typeof(T).IsValueType would get evaluated during JIT compilation which would allow for dead code elimination, but alas: https://github.com/dotnet/corefx/issues/16217 --- .../server/Core/src/Internal/AsyncEnumeratorAdapters.cs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs index 1dcc06d6a2b2..2b1f60f81ec4 100644 --- a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs +++ b/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs @@ -15,13 +15,7 @@ internal static class AsyncEnumeratorAdapters public static IAsyncEnumerator GetAsyncEnumeratorFromAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default(CancellationToken)) { var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken); - - if (typeof(T).IsValueType) - { - return new BoxedAsyncEnumerator(enumerator); - } - - return (IAsyncEnumerator)enumerator; + return enumerator as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumerator); } public static IAsyncEnumerator GetAsyncEnumeratorFromChannel(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) From dbfadd1fdcbf009ac1e7f178ccaef7a68fdc4959 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 17 Jan 2019 12:37:05 -0800 Subject: [PATCH 3/7] Add tests for hubs returning IAsyncEnumerable - Allow hub methods to also return types implementing IAsyncEnumerable --- .../Core/src/Internal/HubMethodDescriptor.cs | 16 +++---- .../HubConnectionHandlerTestUtils/Hubs.cs | 47 ++++++++++++++++++- .../SignalR/test/HubConnectionHandlerTests.cs | 14 ++++-- .../Microsoft.AspNetCore.SignalR.Tests.csproj | 2 +- 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index 1c487225feec..ab9e2e5e98ca 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -23,7 +23,7 @@ internal class HubMethodDescriptor .GetRuntimeMethods() .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumeratorFromChannel)) && m.IsGenericMethod); - private MethodInfo _convertToEnumeratorMethodInfo; + private readonly MethodInfo _convertToEnumeratorMethodInfo; private Func> _convertToEnumerator; public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) @@ -34,25 +34,25 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable)) + if (openReturnType == typeof(IAsyncEnumerable<>)) { - StreamReturnType = closedType.GetGenericArguments()[0]; + StreamReturnType = returnType.GetGenericArguments()[0]; _convertToEnumeratorMethodInfo = GetAsyncEnumeratorFromAsyncEnumerableMethod; break; } - if (openType == typeof(ChannelReader<>)) + if (openReturnType == typeof(ChannelReader<>)) { - StreamReturnType = closedType.GetGenericArguments()[0]; + StreamReturnType = returnType.GetGenericArguments()[0]; _convertToEnumeratorMethodInfo = GetAsyncEnumeratorFromChannelMethod; break; } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 681c4759fd86..129f4bcbf298 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -591,6 +591,26 @@ public async ValueTask> CounterChannelValueTaskAsync(int c return CounterChannel(count); } + public async IAsyncEnumerable CounterAsyncEnumerable(int count) + { + for (int i = 0; i < count; i++) + { + await Task.Yield(); + yield return i.ToString(); + } + } + + public async Task> CounterAsyncEnumerableAsync(int count) + { + await Task.Yield(); + return CounterAsyncEnumerable(count); + } + + public WrappedAsyncEnumerable CounterWrappedAsyncEnumerable(int count) + { + return new WrappedAsyncEnumerable(CounterAsyncEnumerable(count)); + } + public ChannelReader BlockingStream() { return Channel.CreateUnbounded().Reader; @@ -627,6 +647,21 @@ public ChannelReader StreamEcho(ChannelReader source) return output.Reader; } + + public class WrappedAsyncEnumerable : IAsyncEnumerable + { + private readonly IAsyncEnumerable _inner; + + public WrappedAsyncEnumerable(IAsyncEnumerable inner) + { + _inner = inner; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return _inner.GetAsyncEnumerator(cancellationToken); + } + } } public class SimpleHub : Hub @@ -696,6 +731,14 @@ public ChannelReader CancelableStream(CancellationToken token) return channel.Reader; } + public async IAsyncEnumerable CancelableAsyncEnumerableStream(CancellationToken token) + { + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + _tcsService.EndMethod.SetResult(null); + yield break; + } + public ChannelReader CancelableStream2(int ignore, int ignore2, CancellationToken token) { var channel = Channel.CreateBounded(10); @@ -734,8 +777,8 @@ public int SimpleMethod() public class TcsService { - public TaskCompletionSource StartedMethod = new TaskCompletionSource(); - public TaskCompletionSource EndMethod = new TaskCompletionSource(); + public TaskCompletionSource StartedMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public TaskCompletionSource EndMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } public interface ITypedHubClient diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index b51c47c229cb..4519d44f29bd 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -1909,10 +1909,17 @@ public static IEnumerable StreamingMethodAndHubProtocols { get { - foreach (var method in new[] + var methods = new[] { - nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync) - }) + nameof(StreamingHub.CounterChannel), + nameof(StreamingHub.CounterChannelAsync), + nameof(StreamingHub.CounterChannelValueTaskAsync), + nameof(StreamingHub.CounterAsyncEnumerable), + nameof(StreamingHub.CounterAsyncEnumerableAsync), + nameof(StreamingHub.CounterWrappedAsyncEnumerable), + }; + + foreach (var method in methods) { foreach (var protocolName in HubProtocolHelpers.AllProtocolNames) { @@ -3153,6 +3160,7 @@ public async Task UploadStreamAndStreamingMethodClosesStreamsOnServerWhenMethodC [InlineData(nameof(LongRunningHub.CancelableStream))] [InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)] [InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)] + [InlineData(nameof(LongRunningHub.CancelableAsyncEnumerableStream))] public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args) { using (StartVerifiableLog()) diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj index 9b35e64990b5..c4fd5fd37b39 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -1,4 +1,4 @@ - + netcoreapp3.0 From ec694c762d7f5662d9aef2e292f6f8ef2666151b Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 17 Jan 2019 15:03:20 -0800 Subject: [PATCH 4/7] Prefer IAsyncEnumerable over ChannelReader if hub method return type is both --- .../Core/src/Internal/HubMethodDescriptor.cs | 2 +- .../HubConnectionHandlerTestUtils/Hubs.cs | 91 ++++++++++++++++++- .../SignalR/test/HubConnectionHandlerTests.cs | 11 ++- 3 files changed, 94 insertions(+), 10 deletions(-) diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index ab9e2e5e98ca..113c1f733e0c 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -34,7 +34,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable> CounterAsyncEnumerableAsync(int coun return CounterAsyncEnumerable(count); } - public WrappedAsyncEnumerable CounterWrappedAsyncEnumerable(int count) + public AsyncEnumerableImpl CounterAsyncEnumerableImpl(int count) { - return new WrappedAsyncEnumerable(CounterAsyncEnumerable(count)); + return new AsyncEnumerableImpl(CounterAsyncEnumerable(count)); + } + + public AsyncEnumerableImplChannelThrows AsyncEnumerableIsPreferedOverChannelReader(int count) + { + return new AsyncEnumerableImplChannelThrows(CounterChannel(count)); } public ChannelReader BlockingStream() @@ -648,11 +653,11 @@ public ChannelReader StreamEcho(ChannelReader source) return output.Reader; } - public class WrappedAsyncEnumerable : IAsyncEnumerable + public class AsyncEnumerableImpl : IAsyncEnumerable { private readonly IAsyncEnumerable _inner; - public WrappedAsyncEnumerable(IAsyncEnumerable inner) + public AsyncEnumerableImpl(IAsyncEnumerable inner) { _inner = inner; } @@ -662,6 +667,84 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToke return _inner.GetAsyncEnumerator(cancellationToken); } } + + public class AsyncEnumerableImplChannelThrows : ChannelReader, IAsyncEnumerable + { + private ChannelReader _inner; + + public AsyncEnumerableImplChannelThrows(ChannelReader inner) + { + _inner = inner; + } + + public override bool TryRead(out T item) + { + // Not implemented to verify this is consumed as an IAsyncEnumerable instead of a ChannelReader. + throw new NotImplementedException(); + } + + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken = default) + { + // Not implemented to verify this is consumed as an IAsyncEnumerable instead of a ChannelReader. + throw new NotImplementedException(); + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new ChannelAsyncEnumerator(_inner, cancellationToken); + } + + // Copied from AsyncEnumeratorAdapters.ChannelAsyncEnumerator. Implements IAsyncEnumerator instead of IAsyncEnumerator. + private class ChannelAsyncEnumerator : IAsyncEnumerator + { + /// The channel being enumerated. + private readonly ChannelReader _channel; + /// Cancellation token used to cancel the enumeration. + private readonly CancellationToken _cancellationToken; + /// The current element of the enumeration. + private T _current; + + public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) + { + _channel = channel; + _cancellationToken = cancellationToken; + } + + public T Current => _current; + + public ValueTask MoveNextAsync() + { + var result = _channel.ReadAsync(_cancellationToken); + + if (result.IsCompletedSuccessfully) + { + _current = result.Result; + return new ValueTask(true); + } + + return new ValueTask(MoveNextAsyncAwaited(result)); + } + + private async Task MoveNextAsyncAwaited(ValueTask channelReadTask) + { + try + { + _current = await channelReadTask; + } + catch (ChannelClosedException ex) when (ex.InnerException == null) + { + return false; + } + + return true; + } + + public ValueTask DisposeAsync() + { + return default; + } + } + } } public class SimpleHub : Hub diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 4519d44f29bd..acf5ffcde906 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -1763,10 +1763,10 @@ public async Task HubsCanStreamResponses(string method, string protocolName) { var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); - var connectionHandler = serviceProvider.GetService>(); - var invocationBinder = new Mock(); - invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + var invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); using (var client = new TestClient(protocol: protocol, invocationBinder: invocationBinder.Object)) { @@ -1916,7 +1916,8 @@ public static IEnumerable StreamingMethodAndHubProtocols nameof(StreamingHub.CounterChannelValueTaskAsync), nameof(StreamingHub.CounterAsyncEnumerable), nameof(StreamingHub.CounterAsyncEnumerableAsync), - nameof(StreamingHub.CounterWrappedAsyncEnumerable), + nameof(StreamingHub.CounterAsyncEnumerableImpl), + nameof(StreamingHub.AsyncEnumerableIsPreferedOverChannelReader), }; foreach (var method in methods) From 41544a289a99c9950b260beee329600b61dd8441 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 17 Jan 2019 18:04:11 -0800 Subject: [PATCH 5/7] Use async foreach to consume IAsyncEnumerable - Flow CancellationToken to IAsyncEnumerable.GetAsyncEnumerator --- .../src/Internal/AsyncEnumerableAdapters.cs | 135 ++++++++++++++++++ .../src/Internal/AsyncEnumeratorAdapters.cs | 100 ------------- .../Core/src/Internal/DefaultHubDispatcher.cs | 44 ++---- .../Core/src/Internal/HubMethodDescriptor.cs | 39 +++-- .../Microsoft.AspNetCore.SignalR.Core.csproj | 2 +- .../HubConnectionHandlerTestUtils/Hubs.cs | 71 +++++++-- .../SignalR/test/HubConnectionHandlerTests.cs | 13 +- 7 files changed, 234 insertions(+), 170 deletions(-) create mode 100644 src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs delete mode 100644 src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs new file mode 100644 index 000000000000..7b550d9d359b --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs @@ -0,0 +1,135 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + // True-internal because this is a weird and tricky class to use :) + internal static class AsyncEnumerableAdapters + { + public static IAsyncEnumerable MakeCancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default) + { + return new CancelableAsyncEnumerable(asyncEnumerable, cancellationToken); + } + + public static IAsyncEnumerable GetAsyncEnumerableFromChannel(ChannelReader channel, CancellationToken cancellationToken = default) + { + return new ChannelAsyncEnumerable(channel, cancellationToken); + } + + /// Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object. + private class CancelableAsyncEnumerable : IAsyncEnumerable + { + private readonly IAsyncEnumerable _asyncEnumerable; + private readonly CancellationToken _cancellationToken; + + public CancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken) + { + _asyncEnumerable = asyncEnumerable; + _cancellationToken = cancellationToken; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + Debug.Assert(cancellationToken == default); + return new CancelableAsyncEnumerator(_asyncEnumerable.GetAsyncEnumerator(_cancellationToken)); + } + + private class CancelableAsyncEnumerator : IAsyncEnumerator + { + private IAsyncEnumerator _asyncEnumerator; + + public CancelableAsyncEnumerator(IAsyncEnumerator asyncEnumerator) + { + _asyncEnumerator = asyncEnumerator; + } + + public object Current => _asyncEnumerator.Current; + + public ValueTask MoveNextAsync() + { + return _asyncEnumerator.MoveNextAsync(); + } + + public ValueTask DisposeAsync() + { + return _asyncEnumerator.DisposeAsync(); + } + } + } + + /// Provides an IAsyncEnumerable of object for the data in a channel. + private class ChannelAsyncEnumerable : IAsyncEnumerable + { + private readonly ChannelReader _channel; + private readonly CancellationToken _cancellationToken; + + public ChannelAsyncEnumerable(ChannelReader channel, CancellationToken cancellationToken) + { + _channel = channel; + _cancellationToken = cancellationToken; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + Debug.Assert(cancellationToken == default); + return new ChannelAsyncEnumerator(_channel, _cancellationToken); + } + + private class ChannelAsyncEnumerator : IAsyncEnumerator + { + /// The channel being enumerated. + private readonly ChannelReader _channel; + /// Cancellation token used to cancel the enumeration. + private readonly CancellationToken _cancellationToken; + /// The current element of the enumeration. + private T _current; + + public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) + { + _channel = channel; + _cancellationToken = cancellationToken; + } + + public object Current => _current; + + public ValueTask MoveNextAsync() + { + var result = _channel.ReadAsync(_cancellationToken); + + if (result.IsCompletedSuccessfully) + { + _current = result.Result; + return new ValueTask(true); + } + + return new ValueTask(MoveNextAsyncAwaited(result)); + } + + private async Task MoveNextAsyncAwaited(ValueTask channelReadTask) + { + try + { + _current = await channelReadTask; + } + catch (ChannelClosedException ex) when (ex.InnerException == null) + { + return false; + } + + return true; + } + + public ValueTask DisposeAsync() + { + return default; + } + } + } + } +} diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs deleted file mode 100644 index 2b1f60f81ec4..000000000000 --- a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR.Internal -{ - // True-internal because this is a weird and tricky class to use :) - internal static class AsyncEnumeratorAdapters - { - public static IAsyncEnumerator GetAsyncEnumeratorFromAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default(CancellationToken)) - { - var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken); - return enumerator as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumerator); - } - - public static IAsyncEnumerator GetAsyncEnumeratorFromChannel(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) - { - return new ChannelAsyncEnumerator(channel, cancellationToken); - } - - /// Converts an IAsyncEnumerator of T to an IAsyncEnumerator of object. - private class BoxedAsyncEnumerator : IAsyncEnumerator - { - private IAsyncEnumerator _asyncEnumerator; - - public BoxedAsyncEnumerator(IAsyncEnumerator asyncEnumerator) - { - _asyncEnumerator = asyncEnumerator; - } - - public object Current => _asyncEnumerator.Current; - - public ValueTask MoveNextAsync() - { - return _asyncEnumerator.MoveNextAsync(); - } - - public ValueTask DisposeAsync() - { - return _asyncEnumerator.DisposeAsync(); - } - } - - /// Provides an async enumerator for the data in a channel. - private class ChannelAsyncEnumerator : IAsyncEnumerator - { - /// The channel being enumerated. - private readonly ChannelReader _channel; - /// Cancellation token used to cancel the enumeration. - private readonly CancellationToken _cancellationToken; - /// The current element of the enumeration. - private T _current; - - public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) - { - _channel = channel; - _cancellationToken = cancellationToken; - } - - public object Current => _current; - - public ValueTask MoveNextAsync() - { - var result = _channel.ReadAsync(_cancellationToken); - - if (result.IsCompletedSuccessfully) - { - _current = result.Result; - return new ValueTask(true); - } - - return new ValueTask(MoveNextAsyncAwaited(result)); - } - - private async Task MoveNextAsyncAwaited(ValueTask channelReadTask) - { - try - { - _current = await channelReadTask; - } - catch (ChannelClosedException ex) when (ex.InnerException == null) - { - return false; - } - - return true; - } - - public ValueTask DisposeAsync() - { - return default; - } - } - } -} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 1a0f3f89ec5f..99a0b4bd3cbe 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -293,16 +293,20 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { var result = await ExecuteHubMethod(methodExecutor, hub, arguments); - if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts)) + if (result == null) { Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, - $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>."); + $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>."); return; } + cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts); + var enumerable = descriptor.FromReturnedStream(result, cts.Token); + Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage); + _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage); } else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -393,17 +397,17 @@ private ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodIn return scope.DisposeAsync(); } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator, IServiceScope scope, + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable enumerable, IServiceScope scope, IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) { string error = null; try { - while (await enumerator.MoveNextAsync()) + await foreach (var streamItem in enumerable) { // Send the stream item - await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current)); + await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem)); } } catch (ChannelClosedException ex) @@ -422,8 +426,6 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext } finally { - await enumerator.DisposeAsync(); - await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); // Dispose the linked CTS for the stream. @@ -502,10 +504,10 @@ private static async Task IsHubMethodAuthorizedSlow(IServiceProvider provi return authorizationResult.Succeeded; } - private async Task ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation, + private async Task ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse, HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection) { - if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation) + if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse) { // Non-null/empty InvocationId? Blocking if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -518,7 +520,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa return false; } - if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation) + if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse) { Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage); await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, @@ -530,26 +532,6 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa return true; } - private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator enumerator, ref CancellationTokenSource streamCts) - { - if (result != null) - { - if (hubMethodDescriptor.IsStreamable) - { - if (streamCts == null) - { - streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); - } - connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); - enumerator = hubMethodDescriptor.FromReturnedStream(result, streamCts.Token); - return true; - } - } - - enumerator = null; - return false; - } - private void DiscoverHubMethods() { var hubType = typeof(THub); diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index 113c1f733e0c..3bc4d3fce630 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -15,16 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal { internal class HubMethodDescriptor { - private static readonly MethodInfo GetAsyncEnumeratorFromAsyncEnumerableMethod = typeof(AsyncEnumeratorAdapters) + private static readonly MethodInfo MakeCancelableAsyncEnumerableMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumeratorFromAsyncEnumerable)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod); - private static readonly MethodInfo GetAsyncEnumeratorFromChannelMethod = typeof(AsyncEnumeratorAdapters) + private static readonly MethodInfo GetAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumeratorFromChannel)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.GetAsyncEnumerableFromChannel)) && m.IsGenericMethod); - private readonly MethodInfo _convertToEnumeratorMethodInfo; - private Func> _convertToEnumerator; + private readonly MethodInfo _convertToEnumerableMethodInfo; + private Func> _convertToEnumerable; public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { @@ -46,14 +46,14 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable)) { StreamReturnType = returnType.GetGenericArguments()[0]; - _convertToEnumeratorMethodInfo = GetAsyncEnumeratorFromAsyncEnumerableMethod; + _convertToEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod; break; } if (openReturnType == typeof(ChannelReader<>)) { StreamReturnType = returnType.GetGenericArguments()[0]; - _convertToEnumeratorMethodInfo = GetAsyncEnumeratorFromChannelMethod; + _convertToEnumerableMethodInfo = GetAsyncEnumerableFromChannelMethod; break; } } @@ -62,7 +62,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable { // Only streams can take CancellationTokens currently - if (IsStreamable && p.ParameterType == typeof(CancellationToken)) + if (IsStreamResponse && p.ParameterType == typeof(CancellationToken)) { HasSyntheticArguments = true; return false; @@ -99,7 +99,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable StreamReturnType != null; + public bool IsStreamResponse => StreamReturnType != null; public Type StreamReturnType { get; } @@ -107,23 +107,22 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable FromReturnedStream(object stream, CancellationToken cancellationToken) + public IAsyncEnumerable FromReturnedStream(object stream, CancellationToken cancellationToken) { // there is the potential for compile to be called times but this has no harmful effect other than perf - if (_convertToEnumerator == null) + if (_convertToEnumerable == null) { - _convertToEnumerator = CompileConvertToEnumerator(_convertToEnumeratorMethodInfo, StreamReturnType); + _convertToEnumerable = CompileConvertToEnumerable(_convertToEnumerableMethodInfo, StreamReturnType); } - return _convertToEnumerator.Invoke(stream, cancellationToken); + return _convertToEnumerable.Invoke(stream, cancellationToken); } - private static Func> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType) + private static Func> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType) { - // This will call one of two adapter methods to wrap the passed in streamable value and cancellation token - // into an IAsyncEnumerator: - // - AsyncEnumeratorAdapters.GetAsyncEnumeratorFromAsyncEnumerable(asyncEnumerable, cancellationToken); - // - AsyncEnumeratorAdapters.GetAsyncEnumeratorFromChannel(channelReader, cancellationToken); + // This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable: + // - AsyncEnumerableAdapters.GetAsyncEnumerableFromAsyncEnumerable(asyncEnumerable, cancellationToken); + // - AsyncEnumerableAdapters.GetAsyncEnumerableFromChannel(channelReader, cancellationToken); var parameters = new[] { @@ -140,7 +139,7 @@ private static Func> Compile }; var methodCall = Expression.Call(null, genericMethodInfo, methodArguements); - var lambda = Expression.Lambda>>(methodCall, parameters); + var lambda = Expression.Lambda>>(methodCall, parameters); return lambda.Compile(); } } diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index d26d4a30a9ef..1a5f643bad25 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -1,4 +1,4 @@ - + Real-time communication framework for ASP.NET Core. diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index fe3a391ebefc..0d5e64d1c63e 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -694,7 +694,7 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToke return new ChannelAsyncEnumerator(_inner, cancellationToken); } - // Copied from AsyncEnumeratorAdapters.ChannelAsyncEnumerator. Implements IAsyncEnumerator instead of IAsyncEnumerator. + // Copied from AsyncEnumeratorAdapters private class ChannelAsyncEnumerator : IAsyncEnumerator { /// The channel being enumerated. @@ -799,7 +799,7 @@ public async Task> LongRunningStream() return Channel.CreateUnbounded().Reader; } - public ChannelReader CancelableStream(CancellationToken token) + public ChannelReader CancelableStreamSingleParameter(CancellationToken token) { var channel = Channel.CreateBounded(10); @@ -814,15 +814,7 @@ public ChannelReader CancelableStream(CancellationToken token) return channel.Reader; } - public async IAsyncEnumerable CancelableAsyncEnumerableStream(CancellationToken token) - { - _tcsService.StartedMethod.SetResult(null); - await token.WaitForCancellationAsync(); - _tcsService.EndMethod.SetResult(null); - yield break; - } - - public ChannelReader CancelableStream2(int ignore, int ignore2, CancellationToken token) + public ChannelReader CancelableStreamMultiParameter(int ignore, int ignore2, CancellationToken token) { var channel = Channel.CreateBounded(10); @@ -837,7 +829,7 @@ public ChannelReader CancelableStream2(int ignore, int ignore2, Cancellatio return channel.Reader; } - public ChannelReader CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2) + public ChannelReader CancelableStreamMiddleParameter(int ignore, CancellationToken token, int ignore2) { var channel = Channel.CreateBounded(10); @@ -852,10 +844,65 @@ public ChannelReader CancelableStreamMiddle(int ignore, CancellationToken t return channel.Reader; } + public async IAsyncEnumerable CancelableStreamGeneratedAsyncEnumerable(CancellationToken token) + { + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + _tcsService.EndMethod.SetResult(null); + yield break; + } + + public IAsyncEnumerable CancelableStreamCustomAsyncEnumerable() + { + return new CustomAsyncEnumerable(_tcsService); + } + public int SimpleMethod() { return 21; } + + private class CustomAsyncEnumerable : IAsyncEnumerable + { + private readonly TcsService _tcsService; + + public CustomAsyncEnumerable(TcsService tcsService) + { + _tcsService = tcsService; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new CustomAsyncEnumerator(_tcsService, cancellationToken); + } + + private class CustomAsyncEnumerator : IAsyncEnumerator + { + private readonly TcsService _tcsService; + private readonly CancellationToken _cancellationToken; + + public CustomAsyncEnumerator(TcsService tcsService, CancellationToken cancellationToken) + { + _tcsService = tcsService; + _cancellationToken = cancellationToken; + } + + public int Current => throw new NotImplementedException(); + + public ValueTask DisposeAsync() + { + return default; + } + + public async ValueTask MoveNextAsync() + { + _tcsService.StartedMethod.SetResult(null); + await _cancellationToken.WaitForCancellationAsync(); + _tcsService.EndMethod.SetResult(null); + return false; + } + } + } } public class TcsService diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index acf5ffcde906..b7a576954fd3 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -3158,11 +3158,12 @@ public async Task UploadStreamAndStreamingMethodClosesStreamsOnServerWhenMethodC } [Theory] - [InlineData(nameof(LongRunningHub.CancelableStream))] - [InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)] - [InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)] - [InlineData(nameof(LongRunningHub.CancelableAsyncEnumerableStream))] - public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args) + [InlineData(nameof(LongRunningHub.CancelableStreamSingleParameter))] + [InlineData(nameof(LongRunningHub.CancelableStreamMultiParameter), 1, 2)] + [InlineData(nameof(LongRunningHub.CancelableStreamMiddleParameter), 1, 2)] + [InlineData(nameof(LongRunningHub.CancelableStreamGeneratedAsyncEnumerable))] + [InlineData(nameof(LongRunningHub.CancelableStreamCustomAsyncEnumerable))] + public async Task StreamHubMethodCanBeTriggeredOnCancellation(string methodName, params object[] args) { using (StartVerifiableLog()) { @@ -3216,7 +3217,7 @@ public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTrigge { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout(); + var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStreamSingleParameter)).OrTimeout(); // Wait for the stream method to start await tcsService.StartedMethod.Task.OrTimeout(); From 646060aad039559a17b4eace9eb622534d90c990 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 21 Feb 2019 18:38:18 -0800 Subject: [PATCH 6/7] Use IAsyncEnumerable returned by ChannelReader.ReadAllAsync() --- .../SignalRSamples/SignalRSamples.csproj | 1 + .../src/Internal/AsyncEnumerableAdapters.cs | 83 ++----------------- .../Core/src/Internal/DefaultHubDispatcher.cs | 2 +- .../Core/src/Internal/HubMethodDescriptor.cs | 22 ++--- .../Microsoft.AspNetCore.SignalR.Core.csproj | 3 +- .../Microsoft.AspNetCore.SignalR.Tests.csproj | 1 + 6 files changed, 25 insertions(+), 87 deletions(-) diff --git a/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj b/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj index 283fe8d1bb39..6f0b3379d822 100644 --- a/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj +++ b/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj @@ -2,6 +2,7 @@ netcoreapp3.0 + 8.0 diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs index 7b550d9d359b..c0bef6ad7808 100644 --- a/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs +++ b/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs @@ -17,9 +17,9 @@ public static IAsyncEnumerable MakeCancelableAsyncEnumerable(IAsyncEn return new CancelableAsyncEnumerable(asyncEnumerable, cancellationToken); } - public static IAsyncEnumerable GetAsyncEnumerableFromChannel(ChannelReader channel, CancellationToken cancellationToken = default) + public static IAsyncEnumerable MakeCancelableAsyncEnumerableFromChannel(ChannelReader channel, CancellationToken cancellationToken = default) { - return new ChannelAsyncEnumerable(channel, cancellationToken); + return MakeCancelableAsyncEnumerable(channel.ReadAllAsync(), cancellationToken); } /// Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object. @@ -36,15 +36,19 @@ public CancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, Cancellati public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { + // Assume that this will be iterated through with await foreach which always passes a default token. + // Instead use the token from the ctor. Debug.Assert(cancellationToken == default); - return new CancelableAsyncEnumerator(_asyncEnumerable.GetAsyncEnumerator(_cancellationToken)); + + var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken); + return enumeratorOfT as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumeratorOfT); } - private class CancelableAsyncEnumerator : IAsyncEnumerator + private class BoxedAsyncEnumerator : IAsyncEnumerator { private IAsyncEnumerator _asyncEnumerator; - public CancelableAsyncEnumerator(IAsyncEnumerator asyncEnumerator) + public BoxedAsyncEnumerator(IAsyncEnumerator asyncEnumerator) { _asyncEnumerator = asyncEnumerator; } @@ -62,74 +66,5 @@ public ValueTask DisposeAsync() } } } - - /// Provides an IAsyncEnumerable of object for the data in a channel. - private class ChannelAsyncEnumerable : IAsyncEnumerable - { - private readonly ChannelReader _channel; - private readonly CancellationToken _cancellationToken; - - public ChannelAsyncEnumerable(ChannelReader channel, CancellationToken cancellationToken) - { - _channel = channel; - _cancellationToken = cancellationToken; - } - - public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) - { - Debug.Assert(cancellationToken == default); - return new ChannelAsyncEnumerator(_channel, _cancellationToken); - } - - private class ChannelAsyncEnumerator : IAsyncEnumerator - { - /// The channel being enumerated. - private readonly ChannelReader _channel; - /// Cancellation token used to cancel the enumeration. - private readonly CancellationToken _cancellationToken; - /// The current element of the enumeration. - private T _current; - - public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) - { - _channel = channel; - _cancellationToken = cancellationToken; - } - - public object Current => _current; - - public ValueTask MoveNextAsync() - { - var result = _channel.ReadAsync(_cancellationToken); - - if (result.IsCompletedSuccessfully) - { - _current = result.Result; - return new ValueTask(true); - } - - return new ValueTask(MoveNextAsyncAwaited(result)); - } - - private async Task MoveNextAsyncAwaited(ValueTask channelReadTask) - { - try - { - _current = await channelReadTask; - } - catch (ChannelClosedException ex) when (ex.InnerException == null) - { - return false; - } - - return true; - } - - public ValueTask DisposeAsync() - { - return default; - } - } - } } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 99a0b4bd3cbe..c08faa645680 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -397,7 +397,7 @@ private ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodIn return scope.DisposeAsync(); } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable enumerable, IServiceScope scope, + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable enumerable, IServiceScope scope, IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) { string error = null; diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index 3bc4d3fce630..205c1ced7244 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -19,12 +19,12 @@ internal class HubMethodDescriptor .GetRuntimeMethods() .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod); - private static readonly MethodInfo GetAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters) + private static readonly MethodInfo MakeCancelableAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.GetAsyncEnumerableFromChannel)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel)) && m.IsGenericMethod); - private readonly MethodInfo _convertToEnumerableMethodInfo; - private Func> _convertToEnumerable; + private readonly MethodInfo _makeCancelableEnumerableMethodInfo; + private Func> _makeCancelableEnumerable; public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { @@ -46,14 +46,14 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable)) { StreamReturnType = returnType.GetGenericArguments()[0]; - _convertToEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod; + _makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod; break; } if (openReturnType == typeof(ChannelReader<>)) { StreamReturnType = returnType.GetGenericArguments()[0]; - _convertToEnumerableMethodInfo = GetAsyncEnumerableFromChannelMethod; + _makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableFromChannelMethod; break; } } @@ -110,19 +110,19 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable FromReturnedStream(object stream, CancellationToken cancellationToken) { // there is the potential for compile to be called times but this has no harmful effect other than perf - if (_convertToEnumerable == null) + if (_makeCancelableEnumerable == null) { - _convertToEnumerable = CompileConvertToEnumerable(_convertToEnumerableMethodInfo, StreamReturnType); + _makeCancelableEnumerable = CompileConvertToEnumerable(_makeCancelableEnumerableMethodInfo, StreamReturnType); } - return _convertToEnumerable.Invoke(stream, cancellationToken); + return _makeCancelableEnumerable.Invoke(stream, cancellationToken); } private static Func> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType) { // This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable: - // - AsyncEnumerableAdapters.GetAsyncEnumerableFromAsyncEnumerable(asyncEnumerable, cancellationToken); - // - AsyncEnumerableAdapters.GetAsyncEnumerableFromChannel(channelReader, cancellationToken); + // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable(asyncEnumerable, cancellationToken); + // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel(channelReader, cancellationToken); var parameters = new[] { diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index 1a5f643bad25..8470423c2ce3 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -1,10 +1,11 @@ - + Real-time communication framework for ASP.NET Core. netcoreapp3.0 true Microsoft.AspNetCore.SignalR + 8.0 diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj index c4fd5fd37b39..ec113f4e57ab 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -2,6 +2,7 @@ netcoreapp3.0 + 8.0 From 1cf98a5b5c533d61d4b5c3681680157733dbb244 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 22 Feb 2019 16:14:40 -0800 Subject: [PATCH 7/7] Fix client tests --- .../csharp/Client/test/FunctionalTests/HubConnectionTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index f351e669702c..378f7bc7e088 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -822,7 +822,7 @@ bool ExpectedErrors(WriteContext writeContext) await connection.StartAsync().OrTimeout(); var channel = await connection.StreamAsChannelAsync("StreamBroken").OrTimeout(); var ex = await Assert.ThrowsAsync(() => channel.ReadAndCollectAllAsync()).OrTimeout(); - Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message); + Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<> or IAsyncEnumerable<>.", ex.Message); } catch (Exception ex) {