Skip to content

Send CloseMessage in more cases #25693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,11 @@ public async Task ServerLogsErrorIfClientInvokeCannotBeSerialized(string protoco
};

var protocol = HubProtocols[protocolName];
await using (var server = await StartServer<Startup>(write => write.EventId.Name == "FailedWritingMessage"))
await using (var server = await StartServer<Startup>(write =>
{
return write.EventId.Name == "FailedWritingMessage" || write.EventId.Name == "ReceivedCloseWithError"
|| write.EventId.Name == "ShutdownWithError";
}))
{
var connection = CreateHubConnection(server.Url, "/default", HttpTransportType.WebSockets, protocol, LoggerFactory);
var closedTcs = new TaskCompletionSource<Exception>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand All @@ -1361,9 +1365,12 @@ public async Task ServerLogsErrorIfClientInvokeCannotBeSerialized(string protoco
var result = connection.InvokeAsync<string>(nameof(TestHub.CallWithUnserializableObject));

// The connection should close.
Assert.Null(await closedTcs.Task.OrTimeout());
var exception = await closedTcs.Task.OrTimeout();
Assert.Contains("Connection closed with an error.", exception.Message);

await Assert.ThrowsAsync<TaskCanceledException>(() => result).OrTimeout();
var hubException = await Assert.ThrowsAsync<HubException>(() => result).OrTimeout();
Assert.Contains("Connection closed with an error.", hubException.Message);
Assert.Contains(exceptionSubstring, hubException.Message);
}
catch (Exception ex)
{
Expand Down Expand Up @@ -1396,7 +1403,11 @@ public async Task ServerLogsErrorIfReturnValueCannotBeSerialized(string protocol
};

var protocol = HubProtocols[protocolName];
await using (var server = await StartServer<Startup>(write => write.EventId.Name == "FailedWritingMessage"))
await using (var server = await StartServer<Startup>(write =>
{
return write.EventId.Name == "FailedWritingMessage" || write.EventId.Name == "ReceivedCloseWithError"
|| write.EventId.Name == "ShutdownWithError";
}))
{
var connection = CreateHubConnection(server.Url, "/default", HttpTransportType.LongPolling, protocol, LoggerFactory);
var closedTcs = new TaskCompletionSource<Exception>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand All @@ -1408,9 +1419,12 @@ public async Task ServerLogsErrorIfReturnValueCannotBeSerialized(string protocol
var result = connection.InvokeAsync<string>(nameof(TestHub.GetUnserializableObject)).OrTimeout();

// The connection should close.
Assert.Null(await closedTcs.Task.OrTimeout());
var exception = await closedTcs.Task.OrTimeout();
Assert.Contains("Connection closed with an error.", exception.Message);

await Assert.ThrowsAsync<TaskCanceledException>(() => result).OrTimeout();
var hubException = await Assert.ThrowsAsync<HubException>(() => result).OrTimeout();
Assert.Contains("Connection closed with an error.", hubException.Message);
Assert.Contains(exceptionSubstring, hubException.Message);
}
catch (Exception ex)
{
Expand Down
15 changes: 10 additions & 5 deletions src/SignalR/server/Core/src/HubConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,19 @@ internal StreamTracker StreamTracker
internal ConcurrentDictionary<string, CancellationTokenSource> ActiveRequestCancellationSources { get; } = new ConcurrentDictionary<string, CancellationTokenSource>(StringComparer.Ordinal);

public virtual ValueTask WriteAsync(HubMessage message, CancellationToken cancellationToken = default)
{
return WriteAsync(message, ignoreAbort: false, cancellationToken);
}

internal ValueTask WriteAsync(HubMessage message, bool ignoreAbort, CancellationToken cancellationToken = default)
{
// Try to grab the lock synchronously, if we fail, go to the slower path
if (!_writeLock.Wait(0))
{
return new ValueTask(WriteSlowAsync(message, cancellationToken));
return new ValueTask(WriteSlowAsync(message, ignoreAbort, cancellationToken));
}

if (_connectionAborted)
if (_connectionAborted && !ignoreAbort)
{
_writeLock.Release();
return default;
Expand Down Expand Up @@ -272,14 +277,14 @@ private async Task CompleteWriteAsync(ValueTask<FlushResult> task)
}
}

private async Task WriteSlowAsync(HubMessage message, CancellationToken cancellationToken)
private async Task WriteSlowAsync(HubMessage message, bool ignoreAbort, CancellationToken cancellationToken)
{
// Failed to get the lock immediately when entering WriteAsync so await until it is available
await _writeLock.WaitAsync(cancellationToken);

try
{
if (_connectionAborted)
if (_connectionAborted && !ignoreAbort)
{
return;
}
Expand All @@ -301,7 +306,7 @@ private async Task WriteSlowAsync(HubMessage message, CancellationToken cancella
private async Task WriteSlowAsync(SerializedHubMessage message, CancellationToken cancellationToken)
{
// Failed to get the lock immediately when entering WriteAsync so await until it is available
await _writeLock.WaitAsync();
await _writeLock.WaitAsync(cancellationToken);

try
{
Expand Down
2 changes: 1 addition & 1 deletion src/SignalR/server/Core/src/HubConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private async Task SendCloseAsync(HubConnectionContext connection, Exception? ex

try
{
await connection.WriteAsync(closeMessage);
await connection.WriteAsync(closeMessage, ignoreAbort: true);
}
catch (Exception ex)
{
Expand Down
23 changes: 12 additions & 11 deletions src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ public async Task AbortFromHubMethodForcesClientDisconnect()

await client.SendInvocationAsync(nameof(AbortHub.Kill)).OrTimeout();

var close = Assert.IsType<CloseMessage>(await client.ReadAsync().OrTimeout());
Assert.False(close.AllowReconnect);

await connectionHandlerTask.OrTimeout();

Assert.Null(client.TryRead());
Expand Down Expand Up @@ -955,15 +958,18 @@ public async Task HubMethodListeningToConnectionAbortedClosesOnConnectionContext
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);

var invokeTask = client.InvokeAsync(nameof(MethodHub.BlockingMethod));
await client.SendInvocationAsync(nameof(MethodHub.BlockingMethod)).OrTimeout();

client.Connection.Abort();

var closeMessage = Assert.IsType<CloseMessage>(await client.ReadAsync().OrTimeout());
Assert.False(closeMessage.AllowReconnect);

// If this completes then the server has completed the connection
await connectionHandlerTask.OrTimeout();

// Nothing written to connection because it was closed
Assert.False(invokeTask.IsCompleted);
Assert.Null(client.TryRead());
}
}
}
Expand Down Expand Up @@ -1019,16 +1025,11 @@ public async Task HubMethodDoesNotSendResultWhenInvocationIsNonBlocking()
// kill the connection
client.Dispose();

var message = Assert.IsType<CloseMessage>(client.TryRead());
Assert.True(message.AllowReconnect);

// Ensure the client channel is empty
var message = client.TryRead();
switch (message)
{
case CloseMessage close:
break;
default:
Assert.Null(message);
break;
}
Assert.Null(client.TryRead());

await connectionHandlerTask.OrTimeout();
}
Expand Down