diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 898c4aff32..e8c4e46553 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -24,6 +24,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { public class HubConnection { + public static readonly TimeSpan DefaultServerTimeout = TimeSpan.FromSeconds(30); // Server ping rate is 15 sec, this is 2 times that. + private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; private readonly IConnection _connection; @@ -38,9 +40,17 @@ public class HubConnection private int _nextId = 0; private volatile bool _startCalled; + private Timer _timeoutTimer; + private bool _needKeepAlive; public Task Closed { get; } + /// + /// Gets or sets the server timeout interval for the connection. Changes to this value + /// will not be applied until the Keep Alive timer is next reset. + /// + public TimeSpan ServerTimeout { get; set; } = DefaultServerTimeout; + public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory) { if (connection == null) @@ -64,6 +74,9 @@ public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFacto Shutdown(task.Exception); return task; }).Unwrap(); + + // Create the timer for timeout, but disabled by default (we enable it when started). + _timeoutTimer = new Timer(state => ((HubConnection)state).TimeoutElapsed(), this, Timeout.Infinite, Timeout.Infinite); } public async Task StartAsync() @@ -78,6 +91,20 @@ public async Task StartAsync() } } + private void TimeoutElapsed() + { + _connection.AbortAsync(new TimeoutException($"Server timeout ({ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server.")); + } + + private void ResetTimeoutTimer() + { + if (_needKeepAlive) + { + _logger.ResettingKeepAliveTimer(); + _timeoutTimer.Change(ServerTimeout, Timeout.InfiniteTimeSpan); + } + } + private async Task StartAsyncCore() { var transferModeFeature = _connection.Features.Get(); @@ -94,6 +121,7 @@ private async Task StartAsyncCore() transferModeFeature.TransferMode = requestedTransferMode; await _connection.StartAsync(); + _needKeepAlive = _connection.Features.Get() == null; var actualTransferMode = transferModeFeature.TransferMode; _protocolReaderWriter = new HubProtocolReaderWriter(_protocol, GetDataEncoder(requestedTransferMode, actualTransferMode)); @@ -105,6 +133,8 @@ private async Task StartAsyncCore() NegotiationProtocol.WriteMessage(new NegotiationMessage(_protocol.Name), memoryStream); await _connection.SendAsync(memoryStream.ToArray(), _connectionActive.Token); } + + ResetTimeoutTimer(); } private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, TransferMode actualTransferMode) @@ -125,6 +155,7 @@ private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, Transfer private async Task DisposeAsyncCore() { + _timeoutTimer.Dispose(); await _connection.DisposeAsync(); await Closed; } @@ -298,6 +329,7 @@ private async Task SendAsyncCore(string methodName, object[] args, CancellationT private async Task OnDataReceivedAsync(byte[] data) { + ResetTimeoutTimer(); if (_protocolReaderWriter.ReadMessages(data, _binder, out var messages)) { foreach (var message in messages) diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs index c30c4e17f3..1f7edd7c29 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs @@ -85,6 +85,9 @@ internal static class SignalRClientLoggerExtensions private static readonly Action _preparingStreamingInvocation = LoggerMessage.Define(LogLevel.Trace, new EventId(24, nameof(PreparingStreamingInvocation)), "Preparing streaming invocation '{invocationId}' of '{target}', with return type '{returnType}' and {argumentCount} argument(s)."); + private static readonly Action _resettingKeepAliveTimer = + LoggerMessage.Define(LogLevel.Trace, new EventId(25, nameof(ResettingKeepAliveTimer)), "Resetting keep-alive timer, received a message from the server."); + // Category: Streaming and NonStreaming private static readonly Action _invocationCreated = LoggerMessage.Define(LogLevel.Trace, new EventId(0, nameof(InvocationCreated)), "Invocation {invocationId} created."); @@ -282,7 +285,12 @@ public static void StreamItemOnNonStreamInvocation(this ILogger logger, string i public static void ErrorInvokingClientSideMethod(this ILogger logger, string methodName, Exception exception) { - _errorInvokingClientSideMethod(logger, methodName, exception); + _errorInvokingClientSideMethod(logger, methodName, exception); + } + + public static void ResettingKeepAliveTimer(this ILogger logger) + { + _resettingKeepAliveTimer(logger, null); } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj b/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj index e05b0b19f7..9e3d8f1784 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj @@ -3,6 +3,7 @@ Client for ASP.NET Core SignalR netstandard2.0 + Microsoft.AspNetCore.SignalR.Client diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs index 09179c276f..b8efdfa960 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs @@ -1,6 +1,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { - internal static class HubProtocolConstants + public static class HubProtocolConstants { public const int InvocationMessageType = 1; public const int StreamItemMessageType = 2; diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs index be3f5f26fe..7152c33f9b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -13,6 +13,7 @@ public interface IConnection Task StartAsync(); Task SendAsync(byte[] data, CancellationToken cancellationToken); Task DisposeAsync(); + Task AbortAsync(Exception ex); IDisposable OnReceived(Func callback, object state); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index e312bad199..0773017e92 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -95,7 +95,7 @@ public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactor _logger = _loggerFactory.CreateLogger(); _httpOptions = httpOptions; _httpClient = _httpOptions?.HttpMessageHandler == null ? new HttpClient() : new HttpClient(_httpOptions?.HttpMessageHandler); - _httpClient.Timeout = HttpClientTimeout; + _httpClient.Timeout = HttpClientTimeout; _transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory)); } @@ -303,7 +303,7 @@ private async Task StartTransport(Uri connectUrl) // Start the transport, giving it one end of the pipeline try { - await _transport.StartAsync(connectUrl, applicationSide, requestedTransferMode: GetTransferMode(), connectionId: _connectionId); + await _transport.StartAsync(connectUrl, applicationSide, GetTransferMode(), _connectionId, this); // actual transfer mode can differ from the one that was requested so set it on the feature Debug.Assert(_transport.Mode.HasValue, "transfer mode not set after transport started"); @@ -435,11 +435,25 @@ private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToke } } + public async Task AbortAsync(Exception ex) => await DisposeAsyncCore(ex ?? new InvalidOperationException("Connection aborted")).ForceAsync(); + public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); - private async Task DisposeAsyncCore() + private async Task DisposeAsyncCore(Exception ex = null) { - _logger.StoppingClient(_connectionId); + if (ex != null) + { + _logger.AbortingClient(_connectionId, ex); + + // Immediately fault the close task. When the transport shuts down, + // it will trigger the close task to be completed, so we want it to be + // marked faulted before that happens + _closedTcs.TrySetException(ex); + } + else + { + _logger.StoppingClient(_connectionId); + } if (Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected) == ConnectionState.Initial) { @@ -472,6 +486,7 @@ private async Task DisposeAsyncCore() await _receiveLoopTask; } + // If we haven't already done so, trigger the Closed task. _closedTcs.TrySetResult(null); _httpClient?.Dispose(); } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs index 784400db85..620e0e1a1d 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs @@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { public interface ITransport { - Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId); + Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection); Task StopAsync(); TransferMode? Mode { get; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs index 1ac5a0a0d4..a68cfb9659 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs @@ -153,6 +153,9 @@ internal static class SocketClientLoggerExtensions private static readonly Action _exceptionThrownFromCallback = LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromCallback)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{callback}' callback"); + private static readonly Action _abortingClient = + LoggerMessage.Define(LogLevel.Error, new EventId(20, nameof(AbortingClient)), "{time}: Connection Id {connectionId}: Aborting client."); + public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode) { @@ -506,6 +509,14 @@ public static void SendingMessage(this ILogger logger, string connectionId) } } + public static void AbortingClient(this ILogger logger, string connectionId, Exception ex) + { + if (logger.IsEnabled(LogLevel.Error)) + { + _abortingClient(logger, DateTime.Now, connectionId, ex); + } + } + public static void StoppingClient(this ILogger logger, string connectionId) { if (logger.IsEnabled(LogLevel.Information)) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs index b2c8dec3ae..5135a8a56c 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs @@ -2,14 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Net; using System.Net.Http; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -42,13 +42,15 @@ public LongPollingTransport(HttpClient httpClient, HttpOptions httpOptions, ILog _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId) + public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection) { if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) { throw new ArgumentException("Invalid transfer mode.", nameof(requestedTransferMode)); } + connection.Features.Set(new ConnectionInherentKeepAliveFeature(_httpClient.Timeout)); + _application = application; Mode = requestedTransferMode; _connectionId = connectionId; diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index 2f288562a1..c0c5b929cd 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -49,7 +49,7 @@ public ServerSentEventsTransport(HttpClient httpClient, HttpOptions httpOptions, _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId) + public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection) { if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index fdaa2b64c0..f68c2d8b37 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -55,7 +55,7 @@ public WebSocketsTransport(HttpOptions httpOptions, ILoggerFactory loggerFactory _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public async Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId) + public async Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection) { if (url == null) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 86d975fb89..baa342476e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -155,7 +155,7 @@ public async Task CanStopStartingConnection() releaseDisposeTcs.SetResult(null); await disposeTask.OrTimeout(); - transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()), Times.Never); + transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); } [Fact] @@ -263,8 +263,8 @@ public async Task ReceivedCallbackNotRaisedAfterConnectionIsDisposed() var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => { channel = c; return Task.CompletedTask; @@ -311,8 +311,8 @@ public async Task EventsAreNotRunningOnMainLoop() var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => { channel = c; return Task.CompletedTask; @@ -368,8 +368,8 @@ public async Task EventQueueTimeout() var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => { channel = c; return Task.CompletedTask; @@ -413,8 +413,8 @@ public async Task EventQueueTimeoutWithException() var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => { channel = c; return Task.CompletedTask; @@ -925,8 +925,8 @@ public async Task CanStartConnectionWithoutSettingTransferModeFeature() var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => { channel = c; return Task.CompletedTask; @@ -947,7 +947,7 @@ public async Task CanStartConnectionWithoutSettingTransferModeFeature() await connection.DisposeAsync().OrTimeout(); mockTransport.Verify(t => t.StartAsync( - It.IsAny(), It.IsAny>(), TransferMode.Text, It.IsAny()), Times.Once); + It.IsAny(), It.IsAny>(), TransferMode.Text, It.IsAny(), It.IsAny()), Times.Once); Assert.NotNull(transferModeFeature); Assert.Equal(TransferMode.Binary, transferModeFeature.TransferMode); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 35cd3a1a6e..85784f347d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; using Moq; using Xunit; @@ -33,6 +34,7 @@ public async Task StartAsyncCallsConnectionStart() public async Task DisposeAsyncCallsConnectionStart() { var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); connection.Setup(m => m.StartAsync()).Verifiable(); var hubConnection = new HubConnection(connection.Object, Mock.Of(), null); await hubConnection.DisposeAsync(); @@ -185,6 +187,19 @@ public async Task PendingInvocationsAreTerminatedWithExceptionWhenConnectionClos await Assert.ThrowsAsync(async () => await invokeTask); } + [Fact] + public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); + + hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(100); + + await hubConnection.StartAsync().OrTimeout(); + var ex = await Assert.ThrowsAsync(async () => await hubConnection.Closed.OrTimeout()); + Assert.Equal("Server timeout (100.00ms) elapsed without receiving a message from the server.", ex.Message); + } + // Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse private class MockHubProtocol : IHubProtocol { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 55c1452223..e4e5e43a15 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Client.Tests; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; @@ -43,7 +44,7 @@ public async Task LongPollingTransportStopsPollAndSendLoopsWhenTransportStopped( var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); transportActiveTask = longPollingTransport.Running; @@ -79,7 +80,7 @@ public async Task LongPollingTransportStopsWhenPollReceives204() var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = ChannelConnection.Create(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); await longPollingTransport.Running.OrTimeout(); Assert.True(transportToConnection.Reader.Completion.IsCompleted); @@ -132,7 +133,7 @@ public async Task LongPollingTransportResponseWithNoContentDoesNotStopPoll() var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); var data = await transportToConnection.Reader.ReadAllAsync().OrTimeout(); await longPollingTransport.Running.OrTimeout(); @@ -168,7 +169,7 @@ public async Task LongPollingTransportStopsWhenPollRequestFails() var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); var exception = await Assert.ThrowsAsync(async () => await transportToConnection.Reader.Completion.OrTimeout()); @@ -204,7 +205,7 @@ public async Task LongPollingTransportStopsWhenSendRequestFails() var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); await connectionToTransport.Writer.WriteAsync(new SendMessage()); @@ -245,7 +246,7 @@ public async Task LongPollingTransportShutsDownWhenChannelIsClosed() var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); connectionToTransport.Writer.Complete(); @@ -296,7 +297,7 @@ public async Task LongPollingTransportDispatchesMessagesReceivedFromPoll() var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); // Wait for the transport to finish await longPollingTransport.Running.OrTimeout(); @@ -361,7 +362,7 @@ public async Task LongPollingTransportSendsAvailableMessagesWhenTheyArrive() await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout(); // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); connectionToTransport.Writer.Complete(); @@ -404,7 +405,7 @@ public async Task LongPollingTransportSetsTransferMode(TransferMode transferMode var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); Assert.Null(longPollingTransport.Mode); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty, connection: new TestConnection()); Assert.Equal(transferMode, longPollingTransport.Mode); } finally @@ -430,7 +431,7 @@ public async Task LongPollingTransportThrowsForInvalidTransferMode() { var longPollingTransport = new LongPollingTransport(httpClient); var exception = await Assert.ThrowsAsync(() => - longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty)); + longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection())); Assert.Contains("Invalid transfer mode.", exception.Message); Assert.Equal("requestedTransferMode", exception.ParamName); @@ -468,7 +469,7 @@ public async Task LongPollingTransportRePollsIfRequestCancelled() var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()); var completedTask = await Task.WhenAny(completionTcs.Task, longPollingTransport.Running).OrTimeout(); Assert.Equal(completionTcs.Task, completedTask); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index 961700585f..4a4e8f81a0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -55,7 +55,7 @@ public async Task CanStartStopSSETransport() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); await sseTransport.StopAsync().OrTimeout(); @@ -108,7 +108,7 @@ public async Task SSETransportStopsSendAndReceiveLoopsWhenTransportStopped() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); transportActiveTask = sseTransport.Running; Assert.False(transportActiveTask.IsCompleted); @@ -156,7 +156,7 @@ public async Task SSETransportStopsWithErrorIfServerSendsIncompleteResults() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); var exception = await Assert.ThrowsAsync(() => sseTransport.Running.OrTimeout()); Assert.Equal("Incomplete message.", exception.Message); @@ -202,7 +202,7 @@ public async Task SSETransportStopsWithErrorIfSendingMessageFails() var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); await eventStreamTcs.Task; var sendTcs = new TaskCompletionSource(); @@ -249,7 +249,7 @@ public async Task SSETransportStopsIfChannelClosed() var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); connectionToTransport.Writer.TryComplete(null); @@ -278,7 +278,7 @@ public async Task SSETransportStopsIfTheServerClosesTheStream() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout(); Assert.Equal("3:abc", Encoding.ASCII.GetString(message)); @@ -308,7 +308,7 @@ public async Task SSETransportSetsTransferMode(TransferMode transferMode) var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); Assert.Null(sseTransport.Mode); - await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty).OrTimeout(); + await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); Assert.Equal(TransferMode.Text, sseTransport.Mode); await sseTransport.StopAsync().OrTimeout(); } @@ -333,7 +333,7 @@ public async Task SSETransportThrowsForInvalidTransferMode() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); var exception = await Assert.ThrowsAsync(() => - sseTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty)); + sseTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty, connection: Mock.Of())); Assert.Contains("Invalid transfer mode.", exception.Message); Assert.Equal("requestedTransferMode", exception.ParamName); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 6e8c8be30d..aee51722ec 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -48,9 +48,22 @@ public TestConnection(TransferMode? transferMode = null) _receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token); } - public Task DisposeAsync() + public Task AbortAsync(Exception ex) => DisposeCoreAsync(ex); + public Task DisposeAsync() => DisposeCoreAsync(); + + private Task DisposeCoreAsync(Exception ex = null) { - _disposed.TrySetResult(null); + if (ex == null) + { + _closeTcs.TrySetResult(null); + _disposed.TrySetResult(null); + } + else + { + _closeTcs.TrySetException(ex); + _disposed.TrySetException(ex); + } + _receiveShutdownToken.Cancel(); return _receiveLoop; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 2845e0a6c2..946bbbf877 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging.Testing; +using Moq; using Xunit; using Xunit.Abstractions; @@ -41,7 +42,7 @@ public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStop var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, - TransferMode.Binary, connectionId: string.Empty).OrTimeout(); + TransferMode.Binary, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); await webSocketsTransport.StopAsync().OrTimeout(); await webSocketsTransport.Running.OrTimeout(); } @@ -59,7 +60,7 @@ public async Task WebSocketsTransportStopsWhenConnectionChannelClosed() var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, - TransferMode.Binary, connectionId: string.Empty); + TransferMode.Binary, connectionId: string.Empty, connection: Mock.Of()); connectionToTransport.Writer.TryComplete(); await webSocketsTransport.Running.OrTimeout(TimeSpan.FromSeconds(10)); } @@ -78,7 +79,7 @@ public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer(Transf var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connectionId: string.Empty); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connectionId: string.Empty, connection: Mock.Of()); var sendTcs = new TaskCompletionSource(); connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); @@ -119,7 +120,7 @@ public async Task WebSocketsTransportSetsTransferMode(TransferMode transferMode) Assert.Null(webSocketsTransport.Mode); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, - transferMode, connectionId: string.Empty).OrTimeout(); + transferMode, connectionId: string.Empty, connection: Mock.Of()).OrTimeout(); Assert.Equal(transferMode, webSocketsTransport.Mode); await webSocketsTransport.StopAsync().OrTimeout(); @@ -139,7 +140,7 @@ public async Task WebSocketsTransportThrowsForInvalidTransferMode() var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); var exception = await Assert.ThrowsAsync(() => - webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty)); + webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty, connection: Mock.Of())); Assert.Contains("Invalid transfer mode.", exception.Message); Assert.Equal("requestedTransferMode", exception.ParamName);