Skip to content

Support IAsyncEnumerable returns in SignalR hubs #6791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ bool ExpectedErrors(WriteContext writeContext)
await connection.StartAsync().OrTimeout();
var channel = await connection.StreamAsChannelAsync<int>("StreamBroken").OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAndCollectAllAsync()).OrTimeout();
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message);
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<> or IAsyncEnumerable<>.", ex.Message);
}
catch (Exception ex)
{
Expand Down
10 changes: 10 additions & 0 deletions src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Reactive.Linq;
using System.Threading.Channels;
using System.Threading.Tasks;
Expand All @@ -11,6 +12,15 @@ namespace SignalRSamples.Hubs
{
public class Streaming : Hub
{
public async IAsyncEnumerable<int> AsyncEnumerableCounter(int count, int delay)
{
for (var i = 0; i < count; i++)
{
yield return i;
await Task.Delay(delay);
}
}

public ChannelReader<int> ObservableCounter(int count, int delay)
{
var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay))
Expand Down
3 changes: 2 additions & 1 deletion src/SignalR/samples/SignalRSamples/SignalRSamples.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk.Web">

<PropertyGroup>
<TargetFramework>netcoreapp3.0</TargetFramework>
<LangVersion>8.0</LangVersion>
</PropertyGroup>

<ItemGroup>
Expand Down
9 changes: 8 additions & 1 deletion src/SignalR/samples/SignalRSamples/wwwroot/streaming.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ <h2>Controls</h2>
</div>

<div>
<button id="asyncEnumerableButton" name="asyncEnumerable" type="button" disabled>From IAsyncEnumerable</button>
<button id="observableButton" name="observable" type="button" disabled>From Observable</button>
<button id="channelButton" name="channel" type="button" disabled>From Channel</button>
</div>
Expand All @@ -32,7 +33,7 @@ <h2>Results</h2>
let resultsList = document.getElementById('resultsList');
let channelButton = document.getElementById('channelButton');
let observableButton = document.getElementById('observableButton');
let clearButton = document.getElementById('clearButton');
let asyncEnumerableButton = document.getElementById('asyncEnumerableButton');

let connectButton = document.getElementById('connectButton');
let disconnectButton = document.getElementById('disconnectButton');
Expand Down Expand Up @@ -61,6 +62,7 @@ <h2>Results</h2>
connection.onclose(function () {
channelButton.disabled = true;
observableButton.disabled = true;
asyncEnumerableButton.disabled = true;
connectButton.disabled = false;
disconnectButton.disabled = true;

Expand All @@ -71,12 +73,17 @@ <h2>Results</h2>
.then(function () {
channelButton.disabled = false;
observableButton.disabled = false;
asyncEnumerableButton.disabled = false;
connectButton.disabled = true;
disconnectButton.disabled = false;
addLine('resultsList', 'connected', 'green');
});
});

click('asyncEnumerableButton', function () {
run('AsyncEnumerableCounter');
})

click('observableButton', function () {
run('ObservableCounter');
});
Expand Down
70 changes: 70 additions & 0 deletions src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

namespace Microsoft.AspNetCore.SignalR.Internal
{
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumerableAdapters
{
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
{
return new CancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
}

public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default)
{
return MakeCancelableAsyncEnumerable(channel.ReadAllAsync(), cancellationToken);
}

/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
private class CancelableAsyncEnumerable<T> : IAsyncEnumerable<object>
{
private readonly IAsyncEnumerable<T> _asyncEnumerable;
private readonly CancellationToken _cancellationToken;

public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken)
{
_asyncEnumerable = asyncEnumerable;
_cancellationToken = cancellationToken;
}

public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
// Assume that this will be iterated through with await foreach which always passes a default token.
// Instead use the token from the ctor.
Debug.Assert(cancellationToken == default);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

REVIEW: The reason we have to wrap IAsyncEnumerators with non-value-type generic parameters instead of just casting is because we iterate using await foreach which doesn't support flowing a CancellationToken through GetAsyncEnumerator. If we manually iterated over the enumerator, we could save ourselves wrapping in most cases.

So should we manually iterate instead of wrapping the GetAsyncEnumerator call?

@davidfowl

Copy link
Member Author

@halter73 halter73 Feb 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are the only consumer of the IAsyncEnumerable<object> btw

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we could loop over MoveNextAsync() manually and avoid the BoxedAsyncEnumerator? That sounds nice :D

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we should optimize

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could have said something before I merged. It was a while loop that manually called GetAsyncEnumerator/MoveNextAsync until you asked me to change it to await foreach @davidfowl 😆

I figured didn't care about the allocation since a bunch of other allocations happen in this code path anyway. I created #7960 to track this with an XS tag. If we want to change this, it's not difficult.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I care, lets fix it in preview4. Assign it to yourself 😄

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm glad you care 😉


var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken);
return enumeratorOfT as IAsyncEnumerator<object> ?? new BoxedAsyncEnumerator(enumeratorOfT);
}

private class BoxedAsyncEnumerator : IAsyncEnumerator<object>
{
private IAsyncEnumerator<T> _asyncEnumerator;

public BoxedAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
{
_asyncEnumerator = asyncEnumerator;
}

public object Current => _asyncEnumerator.Current;

public ValueTask<bool> MoveNextAsync()
{
return _asyncEnumerator.MoveNextAsync();
}

public ValueTask DisposeAsync()
{
return _asyncEnumerator.DisposeAsync();
}
}
}
}
}
84 changes: 0 additions & 84 deletions src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs

This file was deleted.

44 changes: 13 additions & 31 deletions src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,20 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
{
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);

if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
if (result == null)
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.");
return;
}

cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts);
var enumerable = descriptor.FromReturnedStream(result, cts.Token);

Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
}

else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
Expand Down Expand Up @@ -393,17 +397,17 @@ private ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodIn
return scope.DisposeAsync();
}

private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable<object> enumerable, IServiceScope scope,
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage)
{
string error = null;

try
{
while (await enumerator.MoveNextAsync())
await foreach (var streamItem in enumerable)
{
// Send the stream item
await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem));
}
}
catch (ChannelClosedException ex)
Expand All @@ -422,8 +426,6 @@ private async Task StreamResultsAsync(string invocationId, HubConnectionContext
}
finally
{
(enumerator as IDisposable)?.Dispose();

await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);

// Dispose the linked CTS for the stream.
Expand Down Expand Up @@ -502,10 +504,10 @@ private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provi
return authorizationResult.Succeeded;
}

private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
{
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
{
// Non-null/empty InvocationId? Blocking
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
Expand All @@ -518,7 +520,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
return false;
}

if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
{
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
Expand All @@ -530,26 +532,6 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa
return true;
}

private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
{
if (result != null)
{
if (hubMethodDescriptor.IsChannel)
{
if (streamCts == null)
{
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
}
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
return true;
}
}

enumerator = null;
return false;
}

private void DiscoverHubMethods()
{
var hubType = typeof(THub);
Expand Down
Loading