diff --git a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs index 93556e8d9491..6ebd1dcb25f3 100644 --- a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs +++ b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs @@ -51,10 +51,37 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellati { ((CancellationTokenSource)ctsState).Cancel(); }, _cts); + + return new CancelableEnumerator(_asyncEnumerable.GetAsyncEnumerator(), registration); } return enumerator; } + + private class CancelableEnumerator : IAsyncEnumerator + { + private IAsyncEnumerator _asyncEnumerator; + private readonly CancellationTokenRegistration _cancellationTokenRegistration; + + public T Current => (T)_asyncEnumerator.Current; + + public CancelableEnumerator(IAsyncEnumerator asyncEnumerator, CancellationTokenRegistration registration) + { + _asyncEnumerator = asyncEnumerator; + _cancellationTokenRegistration = registration; + } + + public ValueTask MoveNextAsync() + { + return _asyncEnumerator.MoveNextAsync(); + } + + public ValueTask DisposeAsync() + { + _cancellationTokenRegistration.Dispose(); + return _asyncEnumerator.DisposeAsync(); + } + } } /// Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object. @@ -71,6 +98,10 @@ public CancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, Cancellati public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { + // Assume that this will be iterated through with await foreach which always passes a default token. + // Instead use the token from the ctor. + Debug.Assert(cancellationToken == default); + var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken); return enumeratorOfT as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumeratorOfT); } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 6d097d731def..ee39dcad9d1d 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -403,12 +403,13 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) { string error = null; + try { - await foreach(var item in enumerable.WithCancellation(streamCts.Token)) + await foreach (var streamItem in enumerable) { // Send the stream item - await connection.WriteAsync(new StreamItemMessage(invocationId, item)); + await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem)); } } catch (ChannelClosedException ex)