diff --git a/SignalR.sln b/SignalR.sln index a5c696730c..834b4998d3 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -9,6 +9,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-539 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{83B2C3EB-A3D8-4E6F-9A3C-A380B005EF31}" ProjectSection(SolutionItems) = preProject + build\dependencies.props = build\dependencies.props Directory.Build.props = Directory.Build.props Directory.Build.targets = Directory.Build.targets build\Key.snk = build\Key.snk @@ -85,7 +86,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "JwtSample", "samples\JwtSam EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "JwtClientSample", "samples\JwtClientSample\JwtClientSample.csproj", "{1A953296-E869-4DE2-A693-FD5FCDE27057}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Tests.Utils", "test\Microsoft.AspNetCore.SignalR.Tests.Utils\Microsoft.AspNetCore.SignalR.Tests.Utils.csproj", "{0A0A6135-EA24-4307-95C2-CE1B7E164A5E}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Tests.Utils", "test\Microsoft.AspNetCore.SignalR.Tests.Utils\Microsoft.AspNetCore.SignalR.Tests.Utils.csproj", "{0A0A6135-EA24-4307-95C2-CE1B7E164A5E}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/client-ts/package-lock.json b/client-ts/package-lock.json index ef3bd4b10a..2ea73c4c1c 100644 --- a/client-ts/package-lock.json +++ b/client-ts/package-lock.json @@ -28,6 +28,16 @@ "integrity": "sha512-zT+t9841g1HsjLtPMCYxmb1U4pcZ2TOegAKiomlmj6bIziuaEYHUavxLE9NRwdntY0vOCrgHho6OXjDX7fm/Kw==", "dev": true }, + "JSONStream": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/JSONStream/-/JSONStream-1.3.1.tgz", + "integrity": "sha1-cH92HgHa6eFvG8+TcDt4xwlmV5o=", + "dev": true, + "requires": { + "jsonparse": "1.3.1", + "through": "2.3.8" + } + }, "acorn": { "version": "4.0.13", "resolved": "https://registry.npmjs.org/acorn/-/acorn-4.0.13.tgz", @@ -1017,9 +1027,9 @@ "integrity": "sha1-+GzWzvT1MAyOY+B6TVEvZfv/RTE=", "dev": true, "requires": { + "JSONStream": "1.3.1", "combine-source-map": "0.7.2", "defined": "1.0.0", - "JSONStream": "1.3.1", "through2": "2.0.3", "umd": "3.0.1" } @@ -1047,6 +1057,7 @@ "integrity": "sha1-tanJAgJD8McORnW+yCI7xifkFc4=", "dev": true, "requires": { + "JSONStream": "1.3.1", "assert": "1.4.1", "browser-pack": "6.0.2", "browser-resolve": "1.11.2", @@ -1068,7 +1079,6 @@ "https-browserify": "0.0.1", "inherits": "2.0.3", "insert-module-globals": "7.0.1", - "JSONStream": "1.3.1", "labeled-stream-splicer": "2.0.0", "module-deps": "4.1.1", "os-browserify": "0.1.2", @@ -2487,10 +2497,10 @@ "integrity": "sha1-wDv04BywhtW15azorQr+eInWOMM=", "dev": true, "requires": { + "JSONStream": "1.3.1", "combine-source-map": "0.7.2", "concat-stream": "1.5.2", "is-buffer": "1.1.5", - "JSONStream": "1.3.1", "lexical-scope": "1.2.0", "process": "0.11.10", "through2": "2.0.3", @@ -2763,16 +2773,6 @@ "integrity": "sha1-P02uSpH6wxX3EGL4UhzCOfE2YoA=", "dev": true }, - "JSONStream": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/JSONStream/-/JSONStream-1.3.1.tgz", - "integrity": "sha1-cH92HgHa6eFvG8+TcDt4xwlmV5o=", - "dev": true, - "requires": { - "jsonparse": "1.3.1", - "through": "2.3.8" - } - }, "kind-of": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", @@ -3126,6 +3126,7 @@ "integrity": "sha1-IyFYM/HaE/1gbMuAh7RIUty4If0=", "dev": true, "requires": { + "JSONStream": "1.3.1", "browser-resolve": "1.11.2", "cached-path-relative": "1.0.1", "concat-stream": "1.5.2", @@ -3133,7 +3134,6 @@ "detective": "4.5.0", "duplexer2": "0.1.4", "inherits": "2.0.3", - "JSONStream": "1.3.1", "parents": "1.0.1", "readable-stream": "2.2.11", "resolve": "1.3.3", diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs index ebd805f5b0..cc1a8447c4 100644 --- a/samples/ClientSample/HubSample.cs +++ b/samples/ClientSample/HubSample.cs @@ -30,11 +30,22 @@ public static async Task ExecuteAsync(string baseUrl) baseUrl = string.IsNullOrEmpty(baseUrl) ? "http://localhost:5000/default" : baseUrl; Console.WriteLine("Connecting to {0}", baseUrl); - HubConnection connection = await ConnectAsync(baseUrl); - Console.WriteLine("Connected to {0}", baseUrl); + var connection = new HubConnectionBuilder() + .WithUrl(baseUrl) + .WithConsoleLogger(LogLevel.Trace) + .Build(); try { + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => closeTcs.SetResult(null); + // Set up handler + connection.On("Send", Console.WriteLine); + + await ConnectAsync(connection); + + Console.WriteLine("Connected to {0}", baseUrl); + var sendCts = new CancellationTokenSource(); Console.CancelKeyPress += async (sender, a) => @@ -45,13 +56,10 @@ public static async Task ExecuteAsync(string baseUrl) await connection.DisposeAsync(); }; - // Set up handler - connection.On("Send", Console.WriteLine); - - while (!connection.Closed.IsCompleted) + while (!closeTcs.Task.IsCompleted) { - var completedTask = await Task.WhenAny(Task.Run(() => Console.ReadLine()), connection.Closed); - if (completedTask == connection.Closed) + var completedTask = await Task.WhenAny(Task.Run(() => Console.ReadLine()), closeTcs.Task); + if (completedTask == closeTcs.Task) { break; } @@ -79,19 +87,15 @@ public static async Task ExecuteAsync(string baseUrl) return 0; } - private static async Task ConnectAsync(string baseUrl) + private static async Task ConnectAsync(HubConnection connection) { // Keep trying to until we can start while (true) { - var connection = new HubConnectionBuilder() - .WithUrl(baseUrl) - .WithConsoleLogger(LogLevel.Trace) - .Build(); + try { await connection.StartAsync(); - return connection; } catch (Exception) { diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index 35786286dd..e2e047c1de 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -39,8 +39,9 @@ public static async Task ExecuteAsync(string baseUrl) var connection = new HttpConnection(new Uri(baseUrl), loggerFactory); try { + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => closeTcs.SetResult(null); connection.OnReceived(data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}")); - await connection.StartAsync(); Console.WriteLine($"Connected to {baseUrl}"); @@ -51,7 +52,7 @@ public static async Task ExecuteAsync(string baseUrl) await connection.DisposeAsync(); }; - while (!connection.Closed.IsCompleted) + while (!closeTcs.Task.IsCompleted) { var line = await Task.Run(() => Console.ReadLine(), cts.Token); diff --git a/samples/JwtClientSample/Program.cs b/samples/JwtClientSample/Program.cs index 8358f9c3a3..5f2553645e 100644 --- a/samples/JwtClientSample/Program.cs +++ b/samples/JwtClientSample/Program.cs @@ -34,6 +34,9 @@ private async Task RunConnection(TransportType transportType) .WithJwtBearer(() => _tokens[userId]) .Build(); + var closedTcs = new TaskCompletionSource(); + hubConnection.Closed += e => closedTcs.SetResult(null); + hubConnection.On("Message", (sender, message) => Console.WriteLine($"[{userId}] {sender}: {message}")); await hubConnection.StartAsync(); Console.WriteLine($"[{userId}] Connection Started"); @@ -43,7 +46,7 @@ private async Task RunConnection(TransportType transportType) try { - while (!hubConnection.Closed.IsCompleted) + while (!closedTcs.Task.IsCompleted) { await Task.Delay(1000); ticks++; diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index e8c4e46553..50de12a1a1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -34,16 +34,16 @@ public class HubConnection private HubProtocolReaderWriter _protocolReaderWriter; private readonly object _pendingCallsLock = new object(); - private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource(); private readonly Dictionary _pendingCalls = new Dictionary(); private readonly ConcurrentDictionary> _handlers = new ConcurrentDictionary>(); + private CancellationTokenSource _connectionActive; private int _nextId = 0; private volatile bool _startCalled; private Timer _timeoutTimer; private bool _needKeepAlive; - public Task Closed { get; } + public event Action Closed; /// /// Gets or sets the server timeout interval for the connection. Changes to this value @@ -69,11 +69,7 @@ public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFacto _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); _connection.OnReceived((data, state) => ((HubConnection)state).OnDataReceivedAsync(data), this); - Closed = _connection.Closed.ContinueWith(task => - { - Shutdown(task.Exception); - return task; - }).Unwrap(); + _connection.Closed += e => Shutdown(e); // Create the timer for timeout, but disabled by default (we enable it when started). _timeoutTimer = new Timer(state => ((HubConnection)state).TimeoutElapsed(), this, Timeout.Infinite, Timeout.Infinite); @@ -122,12 +118,14 @@ private async Task StartAsyncCore() transferModeFeature.TransferMode = requestedTransferMode; await _connection.StartAsync(); _needKeepAlive = _connection.Features.Get() == null; + var actualTransferMode = transferModeFeature.TransferMode; _protocolReaderWriter = new HubProtocolReaderWriter(_protocol, GetDataEncoder(requestedTransferMode, actualTransferMode)); _logger.HubProtocol(_protocol.Name); + _connectionActive = new CancellationTokenSource(); using (var memoryStream = new MemoryStream()) { NegotiationProtocol.WriteMessage(new NegotiationMessage(_protocol.Name), memoryStream); @@ -151,13 +149,16 @@ private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, Transfer return new PassThroughEncoder(); } + public async Task StopAsync() => await StopAsyncCore().ForceAsync(); + + private Task StopAsyncCore() => _connection.StopAsync(); + public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); private async Task DisposeAsyncCore() { _timeoutTimer.Dispose(); await _connection.DisposeAsync(); - await Closed; } // TODO: Client return values/tasks? @@ -370,12 +371,12 @@ private async Task OnDataReceivedAsync(byte[] data) } } - private void Shutdown(Exception ex = null) + private void Shutdown(Exception exception = null) { _logger.ShutdownConnection(); - if (ex != null) + if (exception != null) { - _logger.ShutdownWithError(ex); + _logger.ShutdownWithError(exception); } lock (_pendingCallsLock) @@ -388,14 +389,23 @@ private void Shutdown(Exception ex = null) foreach (var outstandingCall in _pendingCalls.Values) { _logger.RemoveInvocation(outstandingCall.InvocationId); - if (ex != null) + if (exception != null) { - outstandingCall.Fail(ex); + outstandingCall.Fail(exception); } outstandingCall.Dispose(); } _pendingCalls.Clear(); } + + try + { + Closed?.Invoke(exception); + } + catch (Exception ex) + { + _logger.ErrorDuringClosedEvent(ex); + } } private async Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken) diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs index 1f7edd7c29..6371a949b0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs @@ -88,6 +88,9 @@ internal static class SignalRClientLoggerExtensions private static readonly Action _resettingKeepAliveTimer = LoggerMessage.Define(LogLevel.Trace, new EventId(25, nameof(ResettingKeepAliveTimer)), "Resetting keep-alive timer, received a message from the server."); + private static readonly Action _errorDuringClosedEvent = + LoggerMessage.Define(LogLevel.Error, new EventId(26, nameof(ErrorDuringClosedEvent)), "An exception was thrown in the handler for the Closed event."); + // Category: Streaming and NonStreaming private static readonly Action _invocationCreated = LoggerMessage.Define(LogLevel.Trace, new EventId(0, nameof(InvocationCreated)), "Invocation {invocationId} created."); @@ -292,5 +295,10 @@ public static void ResettingKeepAliveTimer(this ILogger logger) { _resettingKeepAliveTimer(logger, null); } + + public static void ErrorDuringClosedEvent(this ILogger logger, Exception exception) + { + _errorDuringClosedEvent(logger, exception); + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs index 7152c33f9b..6726927b35 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs @@ -12,12 +12,13 @@ public interface IConnection { Task StartAsync(); Task SendAsync(byte[] data, CancellationToken cancellationToken); + Task StopAsync(); Task DisposeAsync(); Task AbortAsync(Exception ex); IDisposable OnReceived(Func callback, object state); - Task Closed { get; } + event Action Closed; IFeatureCollection Features { get; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 0773017e92..4e0c369795 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -28,17 +28,20 @@ public class HttpConnection : IConnection private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; - private volatile int _connectionState = ConnectionState.Initial; + private volatile ConnectionState _connectionState = ConnectionState.Disconnected; + private readonly object _stateChangeLock = new object(); + private volatile ChannelConnection _transportChannel; private readonly HttpClient _httpClient; private readonly HttpOptions _httpOptions; private volatile ITransport _transport; private volatile Task _receiveLoopTask; - private TaskCompletionSource _startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly TaskCompletionSource _closedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private TaskQueue _eventQueue = new TaskQueue(); + private TaskCompletionSource _startTcs; + private TaskCompletionSource _closeTcs; + private TaskQueue _eventQueue; private readonly ITransportFactory _transportFactory; private string _connectionId; + private Exception _abortException; private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5); private ChannelReader Input => _transportChannel.Input; private ChannelWriter Output => _transportChannel.Output; @@ -49,7 +52,7 @@ public class HttpConnection : IConnection public IFeatureCollection Features { get; } = new FeatureCollection(); - public Task Closed => _closedTcs.Task; + public event Action Closed; public HttpConnection(Uri url) : this(url, TransportType.All) @@ -103,25 +106,26 @@ public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactor private Task StartAsyncCore() { - if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) - != ConnectionState.Initial) + if (ChangeState(from: ConnectionState.Disconnected, to: ConnectionState.Connecting) != ConnectionState.Disconnected) { return Task.FromException( - new InvalidOperationException("Cannot start a connection that is not in the Initial state.")); + new InvalidOperationException($"Cannot start a connection that is not in the {nameof(ConnectionState.Disconnected)} state.")); } + _startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _eventQueue = new TaskQueue(); + StartAsyncInternal() .ContinueWith(t => { - if (t.IsFaulted) + var abortException = _abortException; + if (t.IsFaulted || abortException != null) { - _startTcs.SetException(t.Exception.InnerException); - _closedTcs.TrySetException(t.Exception.InnerException); + _startTcs.SetException(_abortException ?? t.Exception.InnerException); } else if (t.IsCanceled) { _startTcs.SetCanceled(); - _closedTcs.SetCanceled(); } else { @@ -148,8 +152,8 @@ private async Task StartAsyncInternal() var negotiationResponse = await Negotiate(Url, _httpClient, _logger); _connectionId = negotiationResponse.ConnectionId; - // Connection is being stopped while start was in progress - if (_connectionState == ConnectionState.Disconnected) + // Connection is being disposed while start was in progress + if (_connectionState == ConnectionState.Disposed) { _logger.HttpConnectionClosed(_connectionId); return; @@ -164,17 +168,25 @@ private async Task StartAsyncInternal() } catch { - Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + // The connection can now be either in the Connecting or Disposed state - only change the state to + // Disconnected if the connection was in the Connecting state to not resurrect a Disposed connection + ChangeState(from: ConnectionState.Connecting, to: ConnectionState.Disconnected); throw; } - // if the connection is not in the Connecting state here it means the user called DisposeAsync - if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connected, ConnectionState.Connecting) - == ConnectionState.Connecting) + // if the connection is not in the Connecting state here it means the user called DisposeAsync while + // the connection was starting + if (ChangeState(from: ConnectionState.Connecting, to: ConnectionState.Connected) == ConnectionState.Connecting) { + _closeTcs = new TaskCompletionSource(); + _ = Input.Completion.ContinueWith(async t => { - Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + // Grab the exception and then clear it. + // See comment at AbortAsync for more discussion on the thread-safety + // StartAsync can't be called until the ChangeState below, so we're OK. + var abortException = _abortException; + _abortException = null; // There is an inherent race between receive and close. Removing the last message from the channel // makes Input.Completion task completed and runs this continuation. We need to await _receiveLoopTask @@ -187,28 +199,38 @@ private async Task StartAsyncInternal() await _receiveLoopTask; _logger.DrainEvents(_connectionId); - await _eventQueue.Drain(); await Task.WhenAny(_eventQueue.Drain().NoThrow(), Task.Delay(_eventQueueDrainTimeout)); - _httpClient?.Dispose(); _logger.CompleteClosed(_connectionId); - if (t.IsFaulted) - { - _closedTcs.TrySetException(t.Exception.InnerException); - } - if (t.IsCanceled) + + // At this point the connection can be either in the Connected or Disposed state. The state should be changed + // to the Disconnected state only if it was in the Connected state. + // From this point on, StartAsync can be called at any time. + ChangeState(from: ConnectionState.Connected, to: ConnectionState.Disconnected); + + _closeTcs.SetResult(null); + + try { - _closedTcs.TrySetCanceled(); + if (t.IsFaulted) + { + Closed?.Invoke(t.Exception.InnerException); + } + else + { + // Call the closed event. If there was an abort exception, it will be flowed forward + // However, if there wasn't, this will just be null and we're good + Closed?.Invoke(abortException); + } } - else + catch (Exception ex) { - _closedTcs.TrySetResult(null); + // Suppress (but log) the exception, this is user code + _logger.ErrorDuringClosedEvent(ex); } }); - // start receive loop only after the Connected event was raised to - // avoid Received event being raised before the Connected event _receiveLoopTask = ReceiveAsync(); } } @@ -306,7 +328,11 @@ private async Task StartTransport(Uri connectUrl) await _transport.StartAsync(connectUrl, applicationSide, GetTransferMode(), _connectionId, this); // actual transfer mode can differ from the one that was requested so set it on the feature - Debug.Assert(_transport.Mode.HasValue, "transfer mode not set after transport started"); + if (!_transport.Mode.HasValue) + { + // This can happen with custom transports so it should be an exception, not an assert. + throw new InvalidOperationException("Transport was expected to set the Mode property after StartAsync, but it has not been set."); + } SetTransferMode(_transport.Mode.Value); } catch (Exception ex) @@ -435,31 +461,42 @@ private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToke } } - public async Task AbortAsync(Exception ex) => await DisposeAsyncCore(ex ?? new InvalidOperationException("Connection aborted")).ForceAsync(); - - public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); - - private async Task DisposeAsyncCore(Exception ex = null) + // AbortAsync creates a few thread-safety races that we are OK with. + // 1. If the transport shuts down gracefully after AbortAsync is called but BEFORE _abortException is called, then the + // Closed event will not receive the Abort exception. This is OK because technically the transport was shut down gracefully + // before it was aborted + // 2. If the transport is closed gracefully and then AbortAsync is called before it captures the _abortException value + // the graceful shutdown could be turned into an abort. However, again, this is an inherent race between two different conditions: + // The transport shutting down because the server went away, and the user requesting the Abort + // 3. Finally, because this is an instance field, there is a possible race around accidentally re-using _abortException in the restarted + // connection. The scenario here is: AbortAsync(someException); StartAsync(); CloseAsync(); Where the _abortException value from the + // first AbortAsync call is still set at the time CloseAsync gets to calling the Closed event. However, this can't happen because the + // StartAsync method can't be called until the connection state is changed to Disconnected, which happens AFTER the close code + // captures and resets _abortException. + public async Task AbortAsync(Exception exception) => await StopAsyncCore(exception ?? throw new ArgumentNullException(nameof(exception))).ForceAsync(); + + public async Task StopAsync() => await StopAsyncCore(exception: null).ForceAsync(); + + private async Task StopAsyncCore(Exception exception) { - if (ex != null) + lock (_stateChangeLock) { - _logger.AbortingClient(_connectionId, ex); - - // Immediately fault the close task. When the transport shuts down, - // it will trigger the close task to be completed, so we want it to be - // marked faulted before that happens - _closedTcs.TrySetException(ex); - } - else - { - _logger.StoppingClient(_connectionId); + if (!(_connectionState == ConnectionState.Connecting || _connectionState == ConnectionState.Connected)) + { + _logger.SkippingStop(_connectionId); + return; + } } - if (Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected) == ConnectionState.Initial) - { - // the connection was never started so there is nothing to clean up - return; - } + // Note that this method can be called at the same time when the connection is being closed from the server + // side due to an error. We are resilient to this since we merely try to close the channel here and the + // channel can be closed only once. As a result the continuation that does actual job and raises the Closed + // event runs always only once. + + // See comment at AbortAsync for more discussion on the thread-safety of this. + _abortException = exception; + + _logger.StoppingClient(_connectionId); try { @@ -486,8 +523,29 @@ private async Task DisposeAsyncCore(Exception ex = null) await _receiveLoopTask; } - // If we haven't already done so, trigger the Closed task. - _closedTcs.TrySetResult(null); + if (_closeTcs != null) + { + await _closeTcs.Task; + } + } + + public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); + + private async Task DisposeAsyncCore() + { + // This will no-op if we're already stopped + await StopAsyncCore(exception: null); + + if (ChangeState(to: ConnectionState.Disposed) == ConnectionState.Disposed) + { + _logger.SkippingDispose(_connectionId); + + // the connection was already disposed + return; + } + + _logger.DisposingClient(_connectionId); + _httpClient?.Dispose(); } @@ -537,12 +595,39 @@ public void Dispose() } } - private class ConnectionState + private ConnectionState ChangeState(ConnectionState from, ConnectionState to) + { + lock (_stateChangeLock) + { + var state = _connectionState; + if (_connectionState == from) + { + _connectionState = to; + } + + _logger.ConnectionStateChanged(_connectionId, state, to); + return state; + } + } + + private ConnectionState ChangeState(ConnectionState to) + { + lock (_stateChangeLock) + { + var state = _connectionState; + _connectionState = to; + _logger.ConnectionStateChanged(_connectionId, state, to); + return state; + } + } + + // Internal because it's used by logging to avoid ToStringing prematurely. + internal enum ConnectionState { - public const int Initial = 0; - public const int Connecting = 1; - public const int Connected = 2; - public const int Disconnected = 3; + Disconnected, + Connecting, + Connected, + Disposed } private class NegotiationResponse diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs index a68cfb9659..6dc92777ad 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs @@ -151,11 +151,25 @@ internal static class SocketClientLoggerExtensions LoggerMessage.Define(LogLevel.Information, new EventId(18, nameof(StoppingClient)), "{time}: Connection Id {connectionId}: Stopping client."); private static readonly Action _exceptionThrownFromCallback = - LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromCallback)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{callback}' callback"); + LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromCallback)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{callback}' callback."); + + private static readonly Action _disposingClient = + LoggerMessage.Define(LogLevel.Information, new EventId(20, nameof(DisposingClient)), "{time}: Connection Id {connectionId}: Disposing client."); private static readonly Action _abortingClient = - LoggerMessage.Define(LogLevel.Error, new EventId(20, nameof(AbortingClient)), "{time}: Connection Id {connectionId}: Aborting client."); + LoggerMessage.Define(LogLevel.Error, new EventId(21, nameof(AbortingClient)), "{time}: Connection Id {connectionId}: Aborting client."); + + private static readonly Action _errorDuringClosedEvent = + LoggerMessage.Define(LogLevel.Error, new EventId(22, nameof(ErrorDuringClosedEvent)), "An exception was thrown in the handler for the Closed event."); + + private static readonly Action _skippingStop = + LoggerMessage.Define(LogLevel.Debug, new EventId(23, nameof(SkippingStop)), "{time}: Connection Id {connectionId}: Skipping stop, connection is already stopped."); + + private static readonly Action _skippingDispose = + LoggerMessage.Define(LogLevel.Debug, new EventId(24, nameof(SkippingDispose)), "{time}: Connection Id {connectionId}: Skipping dispose, connection is already disposed."); + private static readonly Action _connectionStateChanged = + LoggerMessage.Define(LogLevel.Debug, new EventId(25, nameof(ConnectionStateChanged)), "{time}: Connection Id {connectionId}: Connection state changed from {previousState} to {newState}."); public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode) { @@ -525,6 +539,38 @@ public static void StoppingClient(this ILogger logger, string connectionId) } } + public static void DisposingClient(this ILogger logger, string connectionId) + { + if (logger.IsEnabled(LogLevel.Information)) + { + _disposingClient(logger, DateTime.Now, connectionId, null); + } + } + + public static void SkippingDispose(this ILogger logger, string connectionId) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + _skippingDispose(logger, DateTime.Now, connectionId, null); + } + } + + public static void ConnectionStateChanged(this ILogger logger, string connectionId, HttpConnection.ConnectionState previousState, HttpConnection.ConnectionState newState) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + _connectionStateChanged(logger, DateTime.Now, connectionId, previousState.ToString(), newState.ToString(), null); + } + } + + public static void SkippingStop(this ILogger logger, string connectionId) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + _skippingStop(logger, DateTime.Now, connectionId, null); + } + } + public static void ExceptionThrownFromCallback(this ILogger logger, string connectionId, string callbackName, Exception exception) { if (logger.IsEnabled(LogLevel.Error)) @@ -532,5 +578,10 @@ public static void ExceptionThrownFromCallback(this ILogger logger, string conne _exceptionThrownFromCallback(logger, DateTime.Now, connectionId, callbackName, exception); } } + + public static void ErrorDuringClosedEvent(this ILogger logger, Exception exception) + { + _errorDuringClosedEvent(logger, exception); + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 2513939070..fcf46dbf41 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -103,6 +103,88 @@ public async Task CanSendAndReceiveMessage(IHubProtocol protocol, TransportType } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task CanStopAndStartConnection(IHubProtocol protocol, TransportType transportType, string path) + { + using (StartLog(out var loggerFactory)) + { + const string originalMessage = "SignalR"; + var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + var result = await connection.InvokeAsync("Echo", originalMessage).OrTimeout(); + Assert.Equal(originalMessage, result); + await connection.StopAsync().OrTimeout(); + await connection.StartAsync().OrTimeout(); + result = await connection.InvokeAsync("Echo", originalMessage).OrTimeout(); + Assert.Equal(originalMessage, result); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Fact] + public async Task CanStartConnectionFromClosedEvent() + { + using (StartLog(out var loggerFactory)) + { + var logger = loggerFactory.CreateLogger(); + const string originalMessage = "SignalR"; + var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + "/default"), loggerFactory); + var connection = new HubConnection(httpConnection, new JsonHubProtocol(), loggerFactory); + var restartTcs = new TaskCompletionSource(); + connection.Closed += async e => + { + logger.LogInformation("Closed event triggered"); + if (!restartTcs.Task.IsCompleted) + { + logger.LogInformation("Restarting connection"); + await connection.StartAsync().OrTimeout(); + logger.LogInformation("Restarted connection"); + restartTcs.SetResult(null); + } + }; + + try + { + await connection.StartAsync().OrTimeout(); + var result = await connection.InvokeAsync("Echo", originalMessage).OrTimeout(); + Assert.Equal(originalMessage, result); + + logger.LogInformation("Stopping connection"); + await connection.StopAsync().OrTimeout(); + + logger.LogInformation("Waiting for reconnect"); + await restartTcs.Task.OrTimeout(); + logger.LogInformation("Reconnection complete"); + + result = await connection.InvokeAsync("Echo", originalMessage).OrTimeout(); + Assert.Equal(originalMessage, result); + + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task MethodsAreCaseInsensitive(IHubProtocol protocol, TransportType transportType, string path) @@ -174,12 +256,30 @@ public async Task InvokeNonExistantClientMethodFromServer(IHubProtocol protocol, { var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => + { + if (e != null) + { + closeTcs.SetException(e); + } + else + { + closeTcs.SetResult(null); + } + }; + try { await connection.StartAsync().OrTimeout(); await connection.InvokeAsync("CallHandlerThatDoesntExist").OrTimeout(); await connection.DisposeAsync().OrTimeout(); - await connection.Closed.OrTimeout(); + await closeTcs.Task.OrTimeout(); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} during test: {Message}", ex.GetType().Name, ex.Message); + throw; } finally { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj index 14c53b8017..4e32ffcd3f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj @@ -10,6 +10,12 @@ + + + PreserveNewest + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs new file mode 100644 index 0000000000..f890ccb36e --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs @@ -0,0 +1,275 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Sockets.Client.Tests; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public partial class HttpConnectionTests + { + // Nested class for grouping + public class AbortAsync + { + [Fact] + public async Task AbortAsyncTriggersClosedEventWithException() + { + var connection = CreateConnection(out var closedTask); + try + { + // Start the connection + await connection.StartAsync().OrTimeout(); + + // Abort with an error + var expected = new Exception("Ruh roh!"); + await connection.AbortAsync(expected).OrTimeout(); + + // Verify that it is thrown + var actual = await Assert.ThrowsAsync(async () => await closedTask.OrTimeout()); + Assert.Same(expected, actual); + } + finally + { + // Dispose should be clean and exception free. + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task AbortAsyncWhileStoppingTriggersClosedEventWithException() + { + var connection = CreateConnection(out var closedTask, stopHandler: SyncPoint.Create(2, out var syncPoints)); + + try + { + // Start the connection + await connection.StartAsync().OrTimeout(); + + // Stop normally + var stopTask = connection.StopAsync().OrTimeout(); + + // Wait to reach the first sync point + await syncPoints[0].WaitForSyncPoint().OrTimeout(); + + // Abort with an error + var expected = new Exception("Ruh roh!"); + var abortTask = connection.AbortAsync(expected).OrTimeout(); + + // Wait for the sync point to hit again + await syncPoints[1].WaitForSyncPoint().OrTimeout(); + + // Release sync point 0 + syncPoints[0].Continue(); + + // We should close with the error from Abort (because it was set by the call to Abort even though Stop triggered the close) + var actual = await Assert.ThrowsAsync(async () => await closedTask.OrTimeout()); + Assert.Same(expected, actual); + + // Clean-up + syncPoints[1].Continue(); + await Task.WhenAll(stopTask, abortTask).OrTimeout(); + } + finally + { + // Dispose should be clean and exception free. + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StopAsyncWhileAbortingTriggersClosedEventWithoutException() + { + var connection = CreateConnection(out var closedTask, stopHandler: SyncPoint.Create(2, out var syncPoints)); + + try + { + // Start the connection + await connection.StartAsync().OrTimeout(); + + // Abort with an error + var expected = new Exception("Ruh roh!"); + var abortTask = connection.AbortAsync(expected).OrTimeout(); + + // Wait to reach the first sync point + await syncPoints[0].WaitForSyncPoint().OrTimeout(); + + // Stop normally, without a sync point. + // This should clear the exception, meaning Closed will not "throw" + syncPoints[1].Continue(); + await connection.StopAsync(); + await closedTask.OrTimeout(); + + // Clean-up + syncPoints[0].Continue(); + await abortTask.OrTimeout(); + } + finally + { + // Dispose should be clean and exception free. + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StartAsyncCannotBeCalledWhileAbortAsyncInProgress() + { + var connection = CreateConnection(out var closedTask, stopHandler: SyncPoint.Create(out var syncPoint)); + + try + { + // Start the connection + await connection.StartAsync().OrTimeout(); + + // Abort with an error + var expected = new Exception("Ruh roh!"); + var abortTask = connection.AbortAsync(expected).OrTimeout(); + + // Wait to reach the first sync point + await syncPoint.WaitForSyncPoint().OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => connection.StartAsync().OrTimeout()); + Assert.Equal("Cannot start a connection that is not in the Disconnected state.", ex.Message); + + // Release the sync point and wait for close to complete + // (it will throw the abort exception) + syncPoint.Continue(); + await abortTask.OrTimeout(); + var actual = await Assert.ThrowsAsync(() => closedTask.OrTimeout()); + Assert.Same(expected, actual); + + // We can start now + await connection.StartAsync().OrTimeout(); + + // And we can stop without getting the abort exception. + await connection.StopAsync().OrTimeout(); + } + finally + { + // Dispose should be clean and exception free. + await connection.DisposeAsync().OrTimeout(); + } + } + + private HttpConnection CreateConnection(out Task closedTask, Func stopHandler = null) + { + var httpHandler = new TestHttpMessageHandler(); + var transportFactory = new TestTransportFactory(new TestTransport(stopHandler)); + var connection = new HttpConnection( + new Uri("http://fakeuri.org/"), + transportFactory, + NullLoggerFactory.Instance, + new HttpOptions() + { + HttpMessageHandler = httpHandler, + }); + + var closedTcs = new TaskCompletionSource(); + connection.Closed += ex => + { + if (ex != null) + { + closedTcs.SetException(ex); + } + else + { + closedTcs.SetResult(null); + } + }; + closedTask = closedTcs.Task; + + return connection; + } + + private class TestTransport : ITransport + { + private Channel _application; + private readonly Func _stopHandler; + + public TransferMode? Mode => TransferMode.Text; + + public TestTransport(Func stopHandler) + { + _stopHandler = stopHandler ?? new Func(() => Task.CompletedTask); + } + + public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection) + { + _application = application; + return Task.CompletedTask; + } + + public async Task StopAsync() + { + await _stopHandler(); + _application.Writer.TryComplete(); + } + } + + // Possibly useful as a general-purpose async testing helper? + private class SyncPoint + { + private TaskCompletionSource _atSyncPoint = new TaskCompletionSource(); + private TaskCompletionSource _continueFromSyncPoint = new TaskCompletionSource(); + + // Used by the test code to wait and continue + public Task WaitForSyncPoint() => _atSyncPoint.Task; + public void Continue() => _continueFromSyncPoint.TrySetResult(null); + + // Used by the code under test to wait for the test code to release it. + public Task WaitToContinue() + { + _atSyncPoint.TrySetResult(null); + return _continueFromSyncPoint.Task; + } + + public static Func Create(out SyncPoint syncPoint) + { + var handler = Create(1, out var syncPoints); + syncPoint = syncPoints[0]; + return handler; + } + + /// + /// Creates a re-entrant function that waits for sync points in sequence. + /// + /// The number of sync points to expect + /// The objects that can be used to coordinate the sync point + /// + public static Func Create(int count, out SyncPoint[] syncPoints) + { + // Need to use a local so the closure can capture it. You can't use out vars in a closure. + var localSyncPoints = new SyncPoint[count]; + for (var i = 0; i < count; i += 1) + { + localSyncPoints[i] = new SyncPoint(); + } + + syncPoints = localSyncPoints; + + var counter = 0; + return () => + { + if (counter >= localSyncPoints.Length) + { + return Task.CompletedTask; + } + else + { + var syncPoint = localSyncPoints[counter]; + + counter += 1; + return syncPoint.WaitToContinue(); + } + }; + } + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index baa342476e..f62e655b38 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.Net; using System.Net.Http; using System.Text; @@ -12,14 +13,20 @@ using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Moq; using Moq.Protected; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.Sockets.Client.Tests { - public class HttpConnectionTests + public partial class HttpConnectionTests : LoggedTest { + public HttpConnectionTests(ITestOutputHelper output) : base(output) + { + } + [Fact] public void CannotCreateConnectionWithNullUrl() { @@ -53,7 +60,7 @@ public async Task CannotStartRunningConnection() { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -66,7 +73,7 @@ public async Task CannotStartRunningConnection() var exception = await Assert.ThrowsAsync( async () => await connection.StartAsync()); - Assert.Equal("Cannot start a connection that is not in the Initial state.", exception.Message); + Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); } finally { @@ -75,7 +82,7 @@ await Assert.ThrowsAsync( } [Fact] - public async Task CannotStartStoppedConnection() + public async Task CannotStartConnectionDisposedAfterStarting() { var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -83,7 +90,7 @@ public async Task CannotStartStoppedConnection() .Returns(async (request, cancellationToken) => { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -97,7 +104,7 @@ public async Task CannotStartStoppedConnection() await Assert.ThrowsAsync( async () => await connection.StartAsync()); - Assert.Equal("Cannot start a connection that is not in the Initial state.", exception.Message); + Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); } [Fact] @@ -111,12 +118,12 @@ public async Task CannotStartDisposedConnection() await Assert.ThrowsAsync( async () => await connection.StartAsync()); - Assert.Equal("Cannot start a connection that is not in the Initial state.", exception.Message); + Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); } } [Fact] - public async Task CanStopStartingConnection() + public async Task CanDisposeStartingConnection() { // Used to make sure StartAsync is not completed before DisposeAsync is called var releaseNegotiateTcs = new TaskCompletionSource(); @@ -134,13 +141,25 @@ public async Task CanStopStartingConnection() // allow DisposeAsync to continue once we know we are past the connection state check allowDisposeTcs.SetResult(null); await releaseNegotiateTcs.Task; - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); var transport = new Mock(); - transport.Setup(t => t.StopAsync()).Returns(async () => { await releaseDisposeTcs.Task; }); + Channel channel = null; + transport.SetupGet(t => t.Mode).Returns(TransferMode.Text); + transport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((_, c, __, ___, ____) => + { + channel = c; + return Task.CompletedTask; + }); + transport.Setup(t => t.StopAsync()).Returns(async () => + { + await releaseDisposeTcs.Task; + channel.Writer.TryComplete(); + }); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(transport.Object), loggerFactory: null, httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); @@ -154,8 +173,214 @@ public async Task CanStopStartingConnection() await startTask.OrTimeout(); releaseDisposeTcs.SetResult(null); await disposeTask.OrTimeout(); + } + + [Fact] + public async Task CanStartConnectionThatFailedToStart() + { + var failNegotiate = true; + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + if (ResponseUtils.IsNegotiateRequest(request)) + { + return failNegotiate + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()); + } + + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + + try + { + await connection.StartAsync().OrTimeout(); + } + catch { } + failNegotiate = false; + await connection.StartAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + + [Fact] + public async Task CanStartStoppedConnection() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.IsNegotiateRequest(request) + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + + await connection.StartAsync().OrTimeout(); + await connection.StopAsync().OrTimeout(); + await connection.StartAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + + [Fact] + public async Task CanStopStartingConnection() + { + var allowStopTcs = new TaskCompletionSource(); + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (ResponseUtils.IsNegotiateRequest(request)) + { + allowStopTcs.SetResult(null); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()); + } + else + { + var content = request.Content != null ? await request.Content.ReadAsByteArrayAsync() : null; + return (content?.Length == 1 && content[0] == 0x42) + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => closeTcs.TrySetResult(null); + + var startTask = connection.StartAsync(); + await allowStopTcs.Task.OrTimeout(); + + await Task.WhenAll(startTask, connection.StopAsync()).OrTimeout(); + await closeTcs.Task.OrTimeout(); + } + + [Fact] + public async Task CanStartConnectionAfterConnectionStoppedWithError() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (ResponseUtils.IsNegotiateRequest(request)) + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()); + } + + var content = request.Content != null ? await request.Content.ReadAsByteArrayAsync() : null; + return (content?.Length == 1 && content[0] == 0x42) + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => closeTcs.TrySetResult(null); + + await connection.StartAsync().OrTimeout(); + try + { + await connection.SendAsync(new byte[] { 0x42 }).OrTimeout(); + } + catch { } + await closeTcs.Task.OrTimeout(); + await connection.StartAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + + [Fact] + public async Task CanDisposeStoppedConnection() + { + var connection = new HttpConnection(new Uri("http://fakeuri.org/")); + await connection.StopAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + + [Fact] + public async Task StoppingStoppingConnectionNoOps() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (ResponseUtils.IsNegotiateRequest(request)) + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()); + } + else + { + var content = request.Content != null ? await request.Content.ReadAsByteArrayAsync() : null; + return (content?.Length == 1 && content[0] == 0x42) + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => closeTcs.TrySetResult(null); - transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never); + await connection.StartAsync().OrTimeout(); + await Task.WhenAll(connection.StopAsync().OrTimeout(), connection.StopAsync().OrTimeout()); + await closeTcs.Task.OrTimeout(); + } + + [Fact] + public async Task DisposedStoppingConnectionDisposesConnection() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (ResponseUtils.IsNegotiateRequest(request)) + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()); + } + else + { + var content = request.Content != null ? await request.Content.ReadAsByteArrayAsync() : null; + return (content?.Length == 1 && content[0] == 0x42) + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => closeTcs.TrySetResult(null); + + await connection.StartAsync().OrTimeout(); + await Task.WhenAll(connection.StopAsync().OrTimeout(), connection.DisposeAsync().OrTimeout()); + await closeTcs.Task.OrTimeout(); + + var exception = await Assert.ThrowsAsync(() => connection.StartAsync()); + Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); } [Fact] @@ -176,7 +401,7 @@ public async Task SendThrowsIfConnectionIsDisposed() .Returns(async (request, cancellationToken) => { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -201,7 +426,7 @@ public async Task ClosedEventRaisedWhenTheClientIsBeingStopped() .Returns(async (request, cancellationToken) => { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -209,11 +434,21 @@ public async Task ClosedEventRaisedWhenTheClientIsBeingStopped() var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); - + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => + { + if (e != null) + { + closeTcs.SetException(e); + } + else + { + closeTcs.SetResult(null); + } + }; await connection.StartAsync().OrTimeout(); await connection.DisposeAsync().OrTimeout(); - await connection.Closed.OrTimeout(); - // in case of clean disconnect error should be null + await closeTcs.Task.OrTimeout(); } [Fact] @@ -228,18 +463,30 @@ public async Task ClosedEventRaisedWhenConnectionToServerLost() return request.Method == HttpMethod.Get ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) - : IsNegotiateRequest(request) + : ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => + { + if (e != null) + { + closeTcs.SetException(e); + } + else + { + closeTcs.SetResult(null); + } + }; try { await connection.StartAsync().OrTimeout(); - await Assert.ThrowsAsync(() => connection.Closed.OrTimeout()); + await Assert.ThrowsAsync(() => closeTcs.Task.OrTimeout()); } finally { @@ -247,54 +494,6 @@ public async Task ClosedEventRaisedWhenConnectionToServerLost() } } - [Fact] - public async Task ReceivedCallbackNotRaisedAfterConnectionIsDisposed() - { - var mockHttpHandler = new Mock(); - mockHttpHandler.Protected() - .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Returns(async (request, cancellationToken) => - { - await Task.Yield(); - return IsNegotiateRequest(request) - ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) - : ResponseUtils.CreateResponse(HttpStatusCode.OK); - }); - - var mockTransport = new Mock(); - Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => - { - channel = c; - return Task.CompletedTask; - }); - mockTransport.Setup(t => t.StopAsync()) - .Returns(() => - { - // The connection is now in the Disconnected state so the Received event for - // this message should not be raised - channel.Writer.TryWrite(Array.Empty()); - channel.Writer.TryComplete(); - return Task.CompletedTask; - }); - mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); - - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, - httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); - - var onReceivedInvoked = false; - connection.OnReceived(_ => - { - onReceivedInvoked = true; - return Task.CompletedTask; - }); - - await connection.StartAsync(); - await connection.DisposeAsync(); - Assert.False(onReceivedInvoked); - } - [Fact] public async Task EventsAreNotRunningOnMainLoop() { @@ -304,7 +503,7 @@ public async Task EventsAreNotRunningOnMainLoop() .Returns(async (request, cancellationToken) => { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -355,46 +554,63 @@ public async Task EventsAreNotRunningOnMainLoop() [Fact] public async Task EventQueueTimeout() { - var mockHttpHandler = new Mock(); - mockHttpHandler.Protected() - .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Returns(async (request, cancellationToken) => - { - await Task.Yield(); - return IsNegotiateRequest(request) - ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) - : ResponseUtils.CreateResponse(HttpStatusCode.OK); - }); + using (StartLog(out var loggerFactory)) + { + var logger = loggerFactory.CreateLogger(); - var mockTransport = new Mock(); - Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => - { - channel = c; - return Task.CompletedTask; - }); - mockTransport.Setup(t => t.StopAsync()) - .Returns(() => - { - channel.Writer.TryComplete(); - return Task.CompletedTask; - }); - mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.IsNegotiateRequest(request) + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); - var blockReceiveCallbackTcs = new TaskCompletionSource(); + var mockTransport = new Mock(); + Channel channel = null; + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) => + { + logger.LogInformation("Transport started"); + channel = c; + return Task.CompletedTask; + }); + mockTransport.Setup(t => t.StopAsync()) + .Returns(() => + { + logger.LogInformation("Transport stopped"); + channel.Writer.TryComplete(); + return Task.CompletedTask; + }); + mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, - httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); - connection.OnReceived(_ => blockReceiveCallbackTcs.Task); + var blockReceiveCallbackTcs = new TaskCompletionSource(); + var onReceivedCalledTcs = new TaskCompletionSource(); - await connection.StartAsync(); - channel.Writer.TryWrite(Array.Empty()); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); + connection.OnReceived(async _ => + { + onReceivedCalledTcs.TrySetResult(null); + await blockReceiveCallbackTcs.Task; + }); - // Ensure that SignalR isn't blocked by the receive callback - Assert.False(channel.Reader.TryRead(out var message)); + logger.LogInformation("Starting connection"); + await connection.StartAsync().OrTimeout(); + logger.LogInformation("Started connection"); + channel.Writer.TryWrite(Array.Empty()); + await onReceivedCalledTcs.Task.OrTimeout(); - await connection.DisposeAsync(); + // Ensure that SignalR isn't blocked by the receive callback + Assert.False(channel.Reader.TryRead(out var message)); + + logger.LogInformation("Disposing connection"); + await connection.DisposeAsync().OrTimeout(TimeSpan.FromSeconds(10)); + logger.LogInformation("Disposed connection"); + } } [Fact] @@ -406,7 +622,7 @@ public async Task EventQueueTimeoutWithException() .Returns(async (request, cancellationToken) => { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -449,9 +665,10 @@ public async Task EventQueueTimeoutWithException() public async Task ClosedEventNotRaisedWhenTheClientIsStoppedButWasNeverStarted() { var connection = new HttpConnection(new Uri("http://fakeuri.org/")); - + var closeInvoked = false; + connection.Closed += e => closeInvoked = true; await connection.DisposeAsync(); - Assert.False(connection.Closed.IsCompleted); + Assert.False(closeInvoked); } [Fact] @@ -464,7 +681,7 @@ public async Task TransportIsStoppedWhenConnectionIsStopped() { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -502,7 +719,7 @@ public async Task CanSendData() .Returns(async (request, cancellationToken) => { await Task.Yield(); - if (IsNegotiateRequest(request)) + if (ResponseUtils.IsNegotiateRequest(request)) { return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()); } @@ -557,7 +774,7 @@ public async Task SendAsyncThrowsIfConnectionIsDisposed() content = "T2:T:42;"; } - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); @@ -584,7 +801,7 @@ public async Task CallerReceivesExceptionsFromSendAsync() { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : request.Method == HttpMethod.Post ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) @@ -618,7 +835,7 @@ public async Task CanReceiveData() content = "42"; } - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); @@ -635,18 +852,17 @@ public async Task CanReceiveData() return Task.CompletedTask; }, receiveTcs); - _ = connection.Closed.ContinueWith(task => + connection.Closed += e => { - if (task.Exception != null) + if (e != null) { - receiveTcs.TrySetException(task.Exception); + receiveTcs.TrySetException(e); } else { receiveTcs.TrySetCanceled(); } - return Task.CompletedTask; - }); + }; await connection.StartAsync().OrTimeout(); Assert.Equal("42", await receiveTcs.Task.OrTimeout()); @@ -674,7 +890,7 @@ public async Task CanReceiveDataEvenIfExceptionThrownFromPreviousReceivedEvent() content = "42"; } - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); @@ -698,18 +914,17 @@ public async Task CanReceiveDataEvenIfExceptionThrownFromPreviousReceivedEvent() return Task.CompletedTask; }); - _ = connection.Closed.ContinueWith(task => + connection.Closed += e => { - if (task.Exception != null) + if (e != null) { - receiveTcs.TrySetException(task.Exception); + receiveTcs.TrySetException(e); } else { receiveTcs.TrySetCanceled(); } - return Task.CompletedTask; - }); + }; await connection.StartAsync(); @@ -738,7 +953,7 @@ public async Task CanReceiveDataEvenIfExceptionThrownSynchronouslyFromPreviousRe content = "42"; } - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); @@ -762,18 +977,17 @@ public async Task CanReceiveDataEvenIfExceptionThrownSynchronouslyFromPreviousRe return Task.CompletedTask; }); - _ = connection.Closed.ContinueWith(task => + connection.Closed += e => { - if (task.Exception != null) + if (e != null) { - receiveTcs.TrySetException(task.Exception); + receiveTcs.TrySetException(e); } else { receiveTcs.TrySetCanceled(); } - return Task.CompletedTask; - }); + }; await connection.StartAsync(); @@ -797,7 +1011,7 @@ public async Task CannotSendAfterReceiveThrewException() return request.Method == HttpMethod.Get ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) - : IsNegotiateRequest(request) + : ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -806,12 +1020,21 @@ public async Task CannotSendAfterReceiveThrewException() httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => + { + if (e != null) + { + closeTcs.SetException(e); + } + else + { + closeTcs.SetResult(null); + } + }; await connection.StartAsync().OrTimeout(); - - // Exception in send should shutdown the connection - await Assert.ThrowsAsync(() => connection.Closed.OrTimeout()); - + await Assert.ThrowsAsync(() => closeTcs.Task.OrTimeout()); var exception = await Assert.ThrowsAsync(() => connection.SendAsync(new byte[0])); Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message); @@ -918,7 +1141,7 @@ public async Task CanStartConnectionWithoutSettingTransferModeFeature() .Returns(async (request, cancellationToken) => { await Task.Yield(); - return IsNegotiateRequest(request) + return ResponseUtils.IsNegotiateRequest(request) ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); @@ -977,11 +1200,5 @@ public async Task CorrectlyHandlesQueryStringWhenAppendingNegotiateToUrl(string await connection.StartAsync().OrTimeout(); await connection.DisposeAsync().OrTimeout(); } - - private bool IsNegotiateRequest(HttpRequestMessage request) - { - return request.Method == HttpMethod.Post && - new UriBuilder(request.RequestUri).Path.EndsWith("/negotiate"); - } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 85784f347d..3434658038 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -73,10 +73,11 @@ public async Task ClosedEventRaisedWhenTheClientIsStopped() { var hubConnection = new HubConnection(new TestConnection(), Mock.Of(), null); var closedEventTcs = new TaskCompletionSource(); + hubConnection.Closed += e => closedEventTcs.SetResult(e); await hubConnection.StartAsync().OrTimeout(); await hubConnection.DisposeAsync().OrTimeout(); - await hubConnection.Closed.OrTimeout(); + Assert.Null(await closedEventTcs.Task); } [Fact] @@ -182,9 +183,12 @@ public async Task PendingInvocationsAreTerminatedWithExceptionWhenConnectionClos await hubConnection.StartAsync(); var invokeTask = hubConnection.InvokeAsync("testMethod"); - await hubConnection.DisposeAsync(); - await Assert.ThrowsAsync(async () => await invokeTask); + var exception = new InvalidOperationException(); + mockConnection.Raise(m => m.Closed += null, exception); + + var actualException = await Assert.ThrowsAsync(async () => await invokeTask); + Assert.Equal(exception, actualException); } [Fact] @@ -196,8 +200,11 @@ public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessag hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(100); await hubConnection.StartAsync().OrTimeout(); - var ex = await Assert.ThrowsAsync(async () => await hubConnection.Closed.OrTimeout()); - Assert.Equal("Server timeout (100.00ms) elapsed without receiving a message from the server.", ex.Message); + + var closeTcs = new TaskCompletionSource(); + hubConnection.Closed += ex => closeTcs.TrySetResult(ex); + var exception = Assert.IsType(await closeTcs.Task.OrTimeout()); + Assert.Equal("Server timeout (100.00ms) elapsed without receiving a message from the server.", exception.Message); } // Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index 30dee8e47a..4ebefdd9b0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -4,6 +4,12 @@ $(StandardTestTfms) + + + PreserveNewest + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs index 6651cf24bc..0db15aa9af 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs @@ -1,6 +1,7 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Net; using System.Net.Http; using System.Text; @@ -28,6 +29,12 @@ public static HttpResponseMessage CreateResponse(HttpStatusCode statusCode, Http }; } + public static bool IsNegotiateRequest(HttpRequestMessage request) + { + return request.Method == HttpMethod.Post && + new UriBuilder(request.RequestUri).Path.EndsWith("/negotiate"); + } + public static string CreateNegotiationResponse(string connectionId = "00000000-0000-0000-0000-000000000000", SocketsTransportType? transportTypes = SocketsTransportType.All) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index 4a4e8f81a0..880ae28b40 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -75,25 +75,23 @@ public async Task SSETransportStopsSendAndReceiveLoopsWhenTransportStopped() var mockHttpHandler = new Mock(); mockHttpHandler.Protected() .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Returns(async (request, cancellationToken) => + .Returns((request, cancellationToken) => { - await Task.Yield(); - var mockStream = new Mock(); mockStream .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(async (stream, bufferSize, t) => + { + await Task.Yield(); + var buffer = Encoding.ASCII.GetBytes("data: 3:abc\r\n\r\n"); + while (!eventStreamCts.IsCancellationRequested) { - await Task.Yield(); - var buffer = Encoding.ASCII.GetBytes("data: 3:abc\r\n\r\n"); - while (!eventStreamCts.IsCancellationRequested) - { - await stream.WriteAsync(buffer, 0, buffer.Length); - } - }); + await stream.WriteAsync(buffer, 0, buffer.Length).OrTimeout(); + } + }); mockStream.Setup(s => s.CanRead).Returns(true); - return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; + return Task.FromResult(new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }); }); Task transportActiveTask; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index aee51722ec..16b37fb79a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -30,14 +30,16 @@ internal class TestConnection : IConnection private Task _receiveLoop; private TransferMode? _transferMode; - private readonly TaskCompletionSource _closeTcs = new TaskCompletionSource(); - public Task Closed => _closeTcs.Task; + public event Action Closed; public Task Started => _started.Task; public Task Disposed => _disposed.Task; public ChannelReader SentMessages => _sentMessages.Reader; public ChannelWriter ReceivedMessages => _receivedMessages.Writer; + private bool _closed; + private object _closedLock = new object(); + private readonly List _callbacks = new List(); public IFeatureCollection Features { get; } = new FeatureCollection(); @@ -51,19 +53,12 @@ public TestConnection(TransferMode? transferMode = null) public Task AbortAsync(Exception ex) => DisposeCoreAsync(ex); public Task DisposeAsync() => DisposeCoreAsync(); + // TestConnection isn't restartable + public Task StopAsync() => DisposeAsync(); + private Task DisposeCoreAsync(Exception ex = null) { - if (ex == null) - { - _closeTcs.TrySetResult(null); - _disposed.TrySetResult(null); - } - else - { - _closeTcs.TrySetException(ex); - _disposed.TrySetException(ex); - } - + TriggerClosed(ex); _receiveShutdownToken.Cancel(); return _receiveLoop; } @@ -147,16 +142,28 @@ private async Task ReceiveLoopAsync(CancellationToken token) } } } - _closeTcs.TrySetResult(null); + TriggerClosed(); } catch (OperationCanceledException) { // Do nothing, we were just asked to shut down. - _closeTcs.TrySetResult(null); + TriggerClosed(); } catch (Exception ex) { - _closeTcs.TrySetException(ex); + TriggerClosed(ex); + } + } + + private void TriggerClosed(Exception ex = null) + { + lock (_closedLock) + { + if (!_closed) + { + _closed = true; + Closed?.Invoke(ex); + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs new file mode 100644 index 0000000000..c12f1c05aa --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs @@ -0,0 +1,25 @@ +using System; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Client.Tests; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public class TestHttpMessageHandler : HttpMessageHandler + { + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + if (ResponseUtils.IsNegotiateRequest(request)) + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, + ResponseUtils.CreateNegotiationResponse())); + } + else + { + return Task.FromException(new InvalidOperationException($"Http endpoint not implemented: {request.RequestUri}")); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index a4e695aee0..bc58f85974 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -4,6 +4,12 @@ $(StandardTestTfms) + + + PreserveNewest + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj index 83bc0580a8..5ca77fbb65 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj @@ -4,6 +4,12 @@ $(StandardTestTfms) + + + PreserveNewest + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 7af4600b24..5e50df59ea 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -164,6 +164,19 @@ public async Task ConnectionCanSendAndReceiveMessages(TransportType transportTyp new TransferModeFeature { TransferMode = requestedTransferMode }); try { + var closeTcs = new TaskCompletionSource(); + connection.Closed += e => + { + if (e != null) + { + closeTcs.SetException(e); + } + else + { + closeTcs.SetResult(null); + } + }; + var receiveTcs = new TaskCompletionSource(); connection.OnReceived((data, state) => { @@ -208,7 +221,7 @@ public async Task ConnectionCanSendAndReceiveMessages(TransportType transportTyp logger.LogInformation("Receiving message"); Assert.Equal(message, await receiveTcs.Task.OrTimeout()); logger.LogInformation("Completed receive"); - await connection.Closed.OrTimeout(); + await closeTcs.Task.OrTimeout(); } catch (Exception ex) { @@ -325,11 +338,22 @@ private async Task ServerClosesConnectionWithErrorIfHubCannotBeCreated(Transport try { var closeTcs = new TaskCompletionSource(); + connection.Closed += e => + { + if (e != null) + { + closeTcs.SetException(e); + } + else + { + closeTcs.SetResult(null); + } + }; logger.LogInformation("Starting connection to {url}", url); await connection.StartAsync().OrTimeout(); - await connection.Closed.OrTimeout(); + await closeTcs.Task.OrTimeout(); } catch (Exception ex) { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj index 2fa2733b0d..dda3884804 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -15,6 +15,12 @@ + + + PreserveNewest + + + diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj index 71808e5822..c743871b69 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj @@ -5,6 +5,12 @@ win7-x86 + + + PreserveNewest + + + diff --git a/test/xunit.runner.json b/test/xunit.runner.json new file mode 100644 index 0000000000..e1589333fe --- /dev/null +++ b/test/xunit.runner.json @@ -0,0 +1,4 @@ +{ + "longRunningTestSeconds": 5, + "diagnosticMessages": false +} \ No newline at end of file