Skip to content
This repository was archived by the owner on Dec 18, 2018. It is now read-only.

Exceptions from user's event handlers should be caught and logged #819

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,25 @@ private async Task StartAsyncInternal()
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connected, ConnectionState.Connecting)
== ConnectionState.Connecting)
{
var ignore = _eventQueue.Enqueue(() =>
_ = _eventQueue.Enqueue(async () =>
{
_logger.RaiseConnected(_connectionId);

Connected?.Invoke();

return Task.CompletedTask;
var connectedEventHandler = Connected;
if (connectedEventHandler != null)
{
try
{
await connectedEventHandler.Invoke();
}
catch (Exception ex)
{
_logger.ExceptionThrownFromEventHandler(_connectionId, nameof(Connected), ex);
}
}
});

ignore = Input.Completion.ContinueWith(async t =>
_ = Input.Completion.ContinueWith(async t =>
{
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);

Expand All @@ -183,9 +192,18 @@ private async Task StartAsyncInternal()

_logger.RaiseClosed(_connectionId);

Closed?.Invoke(t.IsFaulted ? t.Exception.InnerException : null);

return Task.CompletedTask;
var closedEventHandler = Closed;
if (closedEventHandler != null)
{
try
{
await closedEventHandler.Invoke(t.IsFaulted ? t.Exception.InnerException : null);
}
catch (Exception ex)
{
_logger.ExceptionThrownFromEventHandler(_connectionId, nameof(Closed), ex);
}
}
});

// start receive loop only after the Connected event was raised to
Expand Down Expand Up @@ -331,19 +349,22 @@ private async Task ReceiveAsync()
if (Input.TryRead(out var buffer))
{
_logger.ScheduleReceiveEvent(_connectionId);
_ = _eventQueue.Enqueue(() =>
_ = _eventQueue.Enqueue(async () =>
{
_logger.RaiseReceiveEvent(_connectionId);

// Making a copy of the Received handler to ensure that its not null
// Can't use the ? operator because we specifically want to check if the handler is null
var receivedHandler = Received;
if (receivedHandler != null)
{
return receivedHandler(buffer);
try
{
await receivedHandler(buffer);
}
catch (Exception ex)
{
_logger.ExceptionThrownFromEventHandler(_connectionId, nameof(Received), ex);
}
}

return Task.CompletedTask;
});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ internal static class SocketClientLoggerExtensions
private static readonly Action<ILogger, DateTime, string, Exception> _stoppingClient =
LoggerMessage.Define<DateTime, string>(LogLevel.Information, 18, "{time}: Connection Id {connectionId}: Stopping client.");

private static readonly Action<ILogger, DateTime, string, string, Exception> _exceptionThrownFromHandler =
LoggerMessage.Define<DateTime, string, string>(LogLevel.Error, 19, "{time}: Connection Id {connectionId}: An exception was thrown from the '{eventHandlerName}' event handler.");


public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode)
{
if (logger.IsEnabled(LogLevel.Information))
Expand Down Expand Up @@ -509,5 +513,13 @@ public static void StoppingClient(this ILogger logger, string connectionId)
_stoppingClient(logger, DateTime.Now, connectionId, null);
}
}

public static void ExceptionThrownFromEventHandler(this ILogger logger, string connectionId, string eventHandlerName, Exception exception)
{
if (logger.IsEnabled(LogLevel.Error))
{
_exceptionThrownFromHandler(logger, DateTime.Now, connectionId, eventHandlerName, exception);
}
}
}
}
124 changes: 123 additions & 1 deletion test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ public async Task EventQueueTimeout()
closedTcs.SetResult(null);
return Task.CompletedTask;
};

await connection.StartAsync();
channel.Out.TryWrite(Array.Empty<byte>());

Expand Down Expand Up @@ -746,6 +746,128 @@ public async Task CanReceiveData()
}
}

[Fact]
public async Task CanReceiveDataEvenIfUserThrowsInConnectedEvent()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test for a synchronous exception.

{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();

var content = string.Empty;

if (request.Method == HttpMethod.Get)
{
content = "42";
}

return request.Method == HttpMethod.Options
? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse())
: ResponseUtils.CreateResponse(HttpStatusCode.OK, content);
});

var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
try
{
connection.Connected += () => Task.FromException(new InvalidOperationException());

var receiveTcs = new TaskCompletionSource<string>();
connection.Received += data =>
{
receiveTcs.TrySetResult(Encoding.UTF8.GetString(data));
return Task.CompletedTask;
};

connection.Closed += e =>
{
if (e != null)
{
receiveTcs.TrySetException(e);
}
else
{
receiveTcs.TrySetCanceled();
}
return Task.CompletedTask;
};

await connection.StartAsync();

Assert.Equal("42", await receiveTcs.Task.OrTimeout());
}
finally
{
await connection.DisposeAsync();
}
}

[Fact]
public async Task CanReceiveDataEvenIfExceptionThrownFromPreviousReceivedEvent()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();

var content = string.Empty;

if (request.Method == HttpMethod.Get)
{
content = "42";
}

return request.Method == HttpMethod.Options
? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse())
: ResponseUtils.CreateResponse(HttpStatusCode.OK, content);
});

var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
try
{

var receiveTcs = new TaskCompletionSource<string>();

var receivedRaised = false;
connection.Received += data =>
{
if (!receivedRaised)
{
receivedRaised = true;
return Task.FromException(new InvalidOperationException());
}

receiveTcs.TrySetResult(Encoding.UTF8.GetString(data));
return Task.CompletedTask;
};

connection.Closed += e =>
{
if (e != null)
{
receiveTcs.TrySetException(e);
}
else
{
receiveTcs.TrySetCanceled();
}
return Task.CompletedTask;
};

await connection.StartAsync();

Assert.Equal("42", await receiveTcs.Task.OrTimeout());
}
finally
{
await connection.DisposeAsync();
}
}


[Fact]
public async Task CannotSendAfterReceiveThrewException()
{
Expand Down