diff --git a/SignalR.sln b/SignalR.sln index c9f1d748fd..a5c696730c 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26923.0 +VisualStudioVersion = 15.0.27110.0 MinimumVisualStudioVersion = 15.0.26730.03 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}" ProjectSection(SolutionItems) = preProject @@ -56,14 +56,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Signal EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Client.TS", "client-ts\Microsoft.AspNetCore.SignalR.Client.TS\Microsoft.AspNetCore.SignalR.Client.TS.csproj", "{333526A4-633B-491A-AC45-CC62A0012D1C}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Common", "Common", "{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B}" - ProjectSection(SolutionItems) = preProject - test\Common\ChannelExtensions.cs = test\Common\ChannelExtensions.cs - test\Common\ServerFixture.cs = test\Common\ServerFixture.cs - test\Common\TaskExtensions.cs = test\Common\TaskExtensions.cs - test\Common\TestHelpers.cs = test\Common\TestHelpers.cs - EndProjectSection -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "client-ts", "client-ts", "{3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B}" ProjectSection(SolutionItems) = preProject client-ts\package.json = client-ts\package.json @@ -91,7 +83,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Signal EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "JwtSample", "samples\JwtSample\JwtSample.csproj", "{6A7491D3-3C97-49BD-A71C-433AED657F30}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "JwtClientSample", "samples\JwtClientSample\JwtClientSample.csproj", "{1A953296-E869-4DE2-A693-FD5FCDE27057}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "JwtClientSample", "samples\JwtClientSample\JwtClientSample.csproj", "{1A953296-E869-4DE2-A693-FD5FCDE27057}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Tests.Utils", "test\Microsoft.AspNetCore.SignalR.Tests.Utils\Microsoft.AspNetCore.SignalR.Tests.Utils.csproj", "{0A0A6135-EA24-4307-95C2-CE1B7E164A5E}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -207,6 +201,10 @@ Global {1A953296-E869-4DE2-A693-FD5FCDE27057}.Debug|Any CPU.Build.0 = Debug|Any CPU {1A953296-E869-4DE2-A693-FD5FCDE27057}.Release|Any CPU.ActiveCfg = Release|Any CPU {1A953296-E869-4DE2-A693-FD5FCDE27057}.Release|Any CPU.Build.0 = Release|Any CPU + {0A0A6135-EA24-4307-95C2-CE1B7E164A5E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {0A0A6135-EA24-4307-95C2-CE1B7E164A5E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {0A0A6135-EA24-4307-95C2-CE1B7E164A5E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {0A0A6135-EA24-4307-95C2-CE1B7E164A5E}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -228,7 +226,6 @@ Global {354335AB-CEE9-4434-A641-78058F6EFE56} = {DA69F624-5398-4884-87E4-B816698CDE65} {455B68D2-C5B6-4BF4-A685-964B07AFAAF8} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {333526A4-633B-491A-AC45-CC62A0012D1C} = {3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B} - {6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {8A4582C8-DC59-4B61-BCE7-119FBAA99EFB} {75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8} = {DA69F624-5398-4884-87E4-B816698CDE65} @@ -240,6 +237,7 @@ Global {0B083AE6-86CA-4E0B-AE02-59154D1FD005} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {6A7491D3-3C97-49BD-A71C-433AED657F30} = {C4BC9889-B49F-41B6-806B-F84941B2549B} {1A953296-E869-4DE2-A693-FD5FCDE27057} = {C4BC9889-B49F-41B6-806B-F84941B2549B} + {0A0A6135-EA24-4307-95C2-CE1B7E164A5E} = {6A35B453-52EC-48AF-89CA-D4A69800F131} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {7945A4E4-ACDB-4F6E-95CA-6AC6E7C2CD59} diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index 2e0596cae2..2d80701b84 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -1,10 +1,12 @@ using System; +using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { @@ -32,7 +34,7 @@ public void GlobalSetup() var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport, application); - _hubLifetimeManager.OnConnectedAsync(new HubConnectionContext(Channel.CreateUnbounded(), connection)).Wait(); + _hubLifetimeManager.OnConnectedAsync(new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance)).Wait(); } _hubContext = new HubContext(_hubLifetimeManager); diff --git a/client-ts/package-lock.json b/client-ts/package-lock.json index 2ea73c4c1c..ef3bd4b10a 100644 --- a/client-ts/package-lock.json +++ b/client-ts/package-lock.json @@ -28,16 +28,6 @@ "integrity": "sha512-zT+t9841g1HsjLtPMCYxmb1U4pcZ2TOegAKiomlmj6bIziuaEYHUavxLE9NRwdntY0vOCrgHho6OXjDX7fm/Kw==", "dev": true }, - "JSONStream": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/JSONStream/-/JSONStream-1.3.1.tgz", - "integrity": "sha1-cH92HgHa6eFvG8+TcDt4xwlmV5o=", - "dev": true, - "requires": { - "jsonparse": "1.3.1", - "through": "2.3.8" - } - }, "acorn": { "version": "4.0.13", "resolved": "https://registry.npmjs.org/acorn/-/acorn-4.0.13.tgz", @@ -1027,9 +1017,9 @@ "integrity": "sha1-+GzWzvT1MAyOY+B6TVEvZfv/RTE=", "dev": true, "requires": { - "JSONStream": "1.3.1", "combine-source-map": "0.7.2", "defined": "1.0.0", + "JSONStream": "1.3.1", "through2": "2.0.3", "umd": "3.0.1" } @@ -1057,7 +1047,6 @@ "integrity": "sha1-tanJAgJD8McORnW+yCI7xifkFc4=", "dev": true, "requires": { - "JSONStream": "1.3.1", "assert": "1.4.1", "browser-pack": "6.0.2", "browser-resolve": "1.11.2", @@ -1079,6 +1068,7 @@ "https-browserify": "0.0.1", "inherits": "2.0.3", "insert-module-globals": "7.0.1", + "JSONStream": "1.3.1", "labeled-stream-splicer": "2.0.0", "module-deps": "4.1.1", "os-browserify": "0.1.2", @@ -2497,10 +2487,10 @@ "integrity": "sha1-wDv04BywhtW15azorQr+eInWOMM=", "dev": true, "requires": { - "JSONStream": "1.3.1", "combine-source-map": "0.7.2", "concat-stream": "1.5.2", "is-buffer": "1.1.5", + "JSONStream": "1.3.1", "lexical-scope": "1.2.0", "process": "0.11.10", "through2": "2.0.3", @@ -2773,6 +2763,16 @@ "integrity": "sha1-P02uSpH6wxX3EGL4UhzCOfE2YoA=", "dev": true }, + "JSONStream": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/JSONStream/-/JSONStream-1.3.1.tgz", + "integrity": "sha1-cH92HgHa6eFvG8+TcDt4xwlmV5o=", + "dev": true, + "requires": { + "jsonparse": "1.3.1", + "through": "2.3.8" + } + }, "kind-of": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", @@ -3126,7 +3126,6 @@ "integrity": "sha1-IyFYM/HaE/1gbMuAh7RIUty4If0=", "dev": true, "requires": { - "JSONStream": "1.3.1", "browser-resolve": "1.11.2", "cached-path-relative": "1.0.1", "concat-stream": "1.5.2", @@ -3134,6 +3133,7 @@ "detective": "4.5.0", "duplexer2": "0.1.4", "inherits": "2.0.3", + "JSONStream": "1.3.1", "parents": "1.0.1", "readable-stream": "2.2.11", "resolve": "1.3.3", diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 6c68af5414..bc6096a19f 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; @@ -17,7 +18,11 @@ public void ConfigureServices(IServiceCollection services) { services.AddSockets(); - services.AddSignalR(); + services.AddSignalR(options => + { + // Faster pings for testing + options.KeepAliveInterval = TimeSpan.FromSeconds(5); + }); // .AddRedis(); services.AddCors(o => diff --git a/samples/SocketsSample/wwwroot/hubs.html b/samples/SocketsSample/wwwroot/hubs.html index ce3678b319..04d5868c23 100644 --- a/samples/SocketsSample/wwwroot/hubs.html +++ b/samples/SocketsSample/wwwroot/hubs.html @@ -1,4 +1,4 @@ - + diff --git a/samples/SocketsSample/wwwroot/sockets.html b/samples/SocketsSample/wwwroot/sockets.html index f971761d89..4f5f492bf2 100644 --- a/samples/SocketsSample/wwwroot/sockets.html +++ b/samples/SocketsSample/wwwroot/sockets.html @@ -1,4 +1,4 @@ - + @@ -31,7 +31,7 @@

Unknown Transport

document.getElementById('transportName').innerHTML = signalR.TransportType[transportType]; let url = 'http://' + document.location.host + '/chat'; - let connection = new signalR.HttpConnection(url, { transport: transportType, logger: new signalR.ConsoleLogger(signalR.LogLevel.Information) }); + let connection = new signalR.HttpConnection(url, { transport: transportType, logging: new signalR.ConsoleLogger(signalR.LogLevel.Information) }); connection.onreceive = function(data) { let child = document.createElement('li'); diff --git a/samples/SocketsSample/wwwroot/streaming.html b/samples/SocketsSample/wwwroot/streaming.html index 1b13148a8d..976b82362b 100644 --- a/samples/SocketsSample/wwwroot/streaming.html +++ b/samples/SocketsSample/wwwroot/streaming.html @@ -1,4 +1,4 @@ - + @@ -65,7 +65,7 @@

Results

}); click('connectButton', function () { - connection = new signalR.HubConnection('/streaming', { transport: transportType, logger: logger }); + connection = new signalR.HubConnection('/streaming', { transport: transportType, logging: logger }); connection.onclose(function () { channelButton.disabled = true; diff --git a/specs/HubProtocol.md b/specs/HubProtocol.md index 7ab37f61d6..54d88ca3a7 100644 --- a/specs/HubProtocol.md +++ b/specs/HubProtocol.md @@ -109,6 +109,8 @@ Keep alive behavior is achieved via the `Ping` message type. **Either endpoint** Ping messages do not have any payload, they are completely empty messages (aside from the encoding necessary to identify the message as a `Ping` message). +It is up to the server implementation to decide how frequently (if at all) `Ping` frames are sent. The ASP.NET Core implementation sends `Ping` frames only when using the Server Sent Events and WebSockets transports, at a default interval of 15 seconds (configurable). However, a `Ping` frame is only sent if 15 seconds elapses since the last message was sent. Clients may choose to use the "Ping rate" to provide a timeout for the server connection. Since the Client can expect the server to send `Ping` frames at regular intervals, even when the connection is idle, it can use that to determine if the server has left without closing the connection. The ASP.NET Core implementation (both JavaScript and C#) use a default timeout window of 30 seconds, which is twice the server ping rate interval. + ## Example Consider the following C# methods diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index dbaab03e65..c036489dbb 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -86,7 +86,7 @@ private Task InvokeAllWhere(string methodName, object[] args, Func WriteAsync(c, message)); + var tasks = group.Values.Select(c => c.WriteAsync(message)); return Task.WhenAll(tasks); } @@ -153,17 +153,6 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection) return Task.CompletedTask; } - private async Task WriteAsync(HubConnectionContext connection, HubInvocationMessage hubMessage) - { - while (await connection.Output.WaitToWriteAsync()) - { - if (connection.Output.TryWrite(hubMessage)) - { - break; - } - } - } - private string GetInvocationId() { var invocationId = Interlocked.Increment(ref _nextInvocationId); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Features/IHubFeature.cs b/src/Microsoft.AspNetCore.SignalR.Core/Features/IHubFeature.cs deleted file mode 100644 index 7b8a0528c7..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Core/Features/IHubFeature.cs +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using Microsoft.AspNetCore.SignalR.Internal; - -namespace Microsoft.AspNetCore.SignalR.Features -{ - public interface IHubFeature - { - HubProtocolReaderWriter ProtocolReaderWriter { get; set; } - } - - public class HubFeature : IHubFeature - { - public HubProtocolReaderWriter ProtocolReaderWriter { get; set; } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 61e9b3e0f2..53accccc7e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -4,43 +4,49 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.ExceptionServices; using System.Security.Claims; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.SignalR.Features; +using Microsoft.AspNetCore.SignalR.Core; +using Microsoft.AspNetCore.SignalR.Core.Internal; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Features; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR { public class HubConnectionContext { private static Action _abortedCallback = AbortConnection; + private static readonly Base64Encoder Base64Encoder = new Base64Encoder(); + private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder(); - private readonly ChannelWriter _output; private readonly ConnectionContext _connectionContext; + private readonly ILogger _logger; private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource(); private readonly TaskCompletionSource _abortCompletedTcs = new TaskCompletionSource(); + private readonly long _keepAliveDuration; - public HubConnectionContext(ChannelWriter output, ConnectionContext connectionContext) + private Task _writingTask = Task.CompletedTask; + private long _lastSendTimestamp = Stopwatch.GetTimestamp(); + private byte[] _pingMessage; + + public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) { - _output = output; + Output = Channel.CreateUnbounded(); _connectionContext = connectionContext; + _logger = loggerFactory.CreateLogger(); ConnectionAbortedToken = _connectionAbortedTokenSource.Token; + _keepAliveDuration = (int)keepAliveInterval.TotalMilliseconds * (Stopwatch.Frequency / 1000); } - private IHubFeature HubFeature => Features.Get(); - - // Used by the HubEndPoint only - internal ChannelReader Input => _connectionContext.Transport; - - internal ExceptionDispatchInfo AbortException { get; private set; } - public virtual CancellationToken ConnectionAbortedToken { get; } public virtual string ConnectionId => _connectionContext.ConnectionId; @@ -53,11 +59,37 @@ public HubConnectionContext(ChannelWriter output, ConnectionContext public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; } - public virtual ChannelWriter Output => _output; + public virtual ChannelReader Input => _connectionContext.Transport.Reader; + + public string UserIdentifier { get; private set; } + + internal virtual Channel Output { get; set; } + + internal ExceptionDispatchInfo AbortException { get; private set; } // Currently used only for streaming methods internal ConcurrentDictionary ActiveRequestCancellationSources { get; } = new ConcurrentDictionary(); + public async Task WriteAsync(HubInvocationMessage message) + { + while (await Output.Writer.WaitToWriteAsync()) + { + if (Output.Writer.TryWrite(message)) + { + return; + } + } + } + + public async Task DisposeAsync() + { + // Nothing should be writing to the HubConnectionContext + Output.Writer.TryComplete(); + + // This should unwind once we complete the output + await _writingTask; + } + public virtual void Abort() { // If we already triggered the token then noop, this isn't thread safe but it's good enough @@ -71,7 +103,62 @@ public virtual void Abort() Task.Factory.StartNew(_abortedCallback, this); } - public string UserIdentifier { get; internal set; } + // Hubs support multiple producers so we set up this loop to copy + // data written to the HubConnectionContext's channel to the transport channel + internal Task StartAsync() + { + return _writingTask = StartAsyncCore(); + } + + internal async Task NegotiateAsync(TimeSpan timeout, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider) + { + try + { + using (var cts = new CancellationTokenSource()) + { + cts.CancelAfter(timeout); + while (await _connectionContext.Transport.Reader.WaitToReadAsync(cts.Token)) + { + while (_connectionContext.Transport.Reader.TryRead(out var buffer)) + { + if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) + { + var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this); + + var transportCapabilities = Features.Get()?.TransportCapabilities + ?? throw new InvalidOperationException("Unable to read transport capabilities."); + + var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) + ? (IDataEncoder)Base64Encoder + : PassThroughEncoder; + + var transferModeFeature = Features.Get() ?? + throw new InvalidOperationException("Unable to read transfer mode."); + + transferModeFeature.TransferMode = + (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0) + ? TransferMode.Binary + : TransferMode.Text; + + ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); + + _logger.UsingHubProtocol(protocol.Name); + + UserIdentifier = userIdProvider.GetUserId(this); + + return true; + } + } + } + } + } + catch (OperationCanceledException) + { + _logger.NegotiateCanceled(); + } + + return false; + } internal void Abort(Exception exception) { @@ -86,6 +173,68 @@ internal Task AbortAsync() return _abortCompletedTcs.Task; } + private async Task StartAsyncCore() + { + if (Features.Get() == null) + { + Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called"); + _pingMessage = ProtocolReaderWriter.WriteMessage(PingMessage.Instance); + _connectionContext.Features.Get()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this); + } + + try + { + while (await Output.Reader.WaitToReadAsync()) + { + while (Output.Reader.TryRead(out var hubMessage)) + { + var buffer = ProtocolReaderWriter.WriteMessage(hubMessage); + while (await _connectionContext.Transport.Writer.WaitToWriteAsync()) + { + if (_connectionContext.Transport.Writer.TryWrite(buffer)) + { + Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); + break; + } + } + } + } + } + catch (Exception ex) + { + Abort(ex); + } + } + + private void KeepAliveTick() + { + // Implements the keep-alive tick behavior + // Each tick, we check if the time since the last send is larger than the keep alive duration (in ticks). + // If it is, we send a ping frame, if not, we no-op on this tick. This means that in the worst case, the + // true "ping rate" of the server could be (_hubOptions.KeepAliveInterval + HubEndPoint.KeepAliveTimerInterval), + // because if the interval elapses right after the last tick of this timer, it won't be detected until the next tick. + Debug.Assert(_pingMessage != null, "Expected the ping message to be prepared before the first heartbeat tick"); + + if (Stopwatch.GetTimestamp() - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration) + { + // Haven't sent a message for the entire keep-alive duration, so send a ping. + // If the transport channel is full, this will fail, but that's OK because + // adding a Ping message when the transport is full is unnecessary since the + // transport is still in the process of sending frames. + if (_connectionContext.Transport.Writer.TryWrite(_pingMessage)) + { + _logger.SentPing(); + } + else + { + // This isn't necessarily an error, it just indicates that the transport is applying backpressure right now. + _logger.TransportBufferFull(); + } + + Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); + } + } + private static void AbortConnection(object state) { var connection = (HubConnectionContext)state; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 0e9ce1d65c..0d904a2d83 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -8,17 +8,14 @@ using System.Reflection; using System.Security.Claims; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Core.Internal; -using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal; -using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; @@ -28,13 +25,11 @@ namespace Microsoft.AspNetCore.SignalR { public class HubEndPoint : IInvocationBinder where THub : Hub { - private static readonly Base64Encoder Base64Encoder = new Base64Encoder(); - private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder(); - private readonly Dictionary _methods = new Dictionary(StringComparer.OrdinalIgnoreCase); private readonly HubLifetimeManager _lifetimeManager; private readonly IHubContext _hubContext; + private readonly ILoggerFactory _loggerFactory; private readonly ILogger> _logger; private readonly IServiceScopeFactory _serviceScopeFactory; private readonly IHubProtocolResolver _protocolResolver; @@ -45,15 +40,16 @@ public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, IHubContext hubContext, IOptions hubOptions, - ILogger> logger, + ILoggerFactory loggerFactory, IServiceScopeFactory serviceScopeFactory, IUserIdProvider userIdProvider) { _protocolResolver = protocolResolver; _lifetimeManager = lifetimeManager; _hubContext = hubContext; + _loggerFactory = loggerFactory; _hubOptions = hubOptions.Value; - _logger = logger; + _logger = loggerFactory.CreateLogger>(); _serviceScopeFactory = serviceScopeFactory; _userIdProvider = userIdProvider; @@ -62,50 +58,15 @@ public HubEndPoint(HubLifetimeManager lifetimeManager, public async Task OnConnectedAsync(ConnectionContext connection) { - var output = Channel.CreateUnbounded(); - - // Set the hub feature before doing anything else. This stores - // all the relevant state for a SignalR Hub connection. - connection.Features.Set(new HubFeature()); + var connectionContext = new HubConnectionContext(connection, _hubOptions.KeepAliveInterval, _loggerFactory); - var connectionContext = new HubConnectionContext(output, connection); - - if (!await ProcessNegotiate(connectionContext)) + if (!await connectionContext.NegotiateAsync(_hubOptions.NegotiateTimeout, _protocolResolver, _userIdProvider)) { return; } - connectionContext.UserIdentifier = _userIdProvider.GetUserId(connectionContext); - - // Hubs support multiple producers so we set up this loop to copy - // data written to the HubConnectionContext's channel to the transport channel - var protocolReaderWriter = connectionContext.ProtocolReaderWriter; - async Task WriteToTransport() - { - try - { - while (await output.Reader.WaitToReadAsync()) - { - while (output.Reader.TryRead(out var hubMessage)) - { - var buffer = protocolReaderWriter.WriteMessage(hubMessage); - while (await connection.Transport.Writer.WaitToWriteAsync()) - { - if (connection.Transport.Writer.TryWrite(buffer)) - { - break; - } - } - } - } - } - catch (Exception ex) - { - connectionContext.Abort(ex); - } - } - - var writingOutputTask = WriteToTransport(); + // We don't need to hold this task, it's also held internally and awaited by DisposeAsync. + _ = connectionContext.StartAsync(); try { @@ -116,61 +77,10 @@ async Task WriteToTransport() { await _lifetimeManager.OnDisconnectedAsync(connectionContext); - // Nothing should be writing to the HubConnectionContext - output.Writer.TryComplete(); - - // This should unwind once we complete the output - await writingOutputTask; + await connectionContext.DisposeAsync(); } } - private async Task ProcessNegotiate(HubConnectionContext connection) - { - try - { - using (var cts = new CancellationTokenSource()) - { - cts.CancelAfter(_hubOptions.NegotiateTimeout); - while (await connection.Input.WaitToReadAsync(cts.Token)) - { - while (connection.Input.TryRead(out var buffer)) - { - if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) - { - var protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); - - var transportCapabilities = connection.Features.Get()?.TransportCapabilities - ?? throw new InvalidOperationException("Unable to read transport capabilities."); - - var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) - ? (IDataEncoder)Base64Encoder - : PassThroughEncoder; - - var transferModeFeature = connection.Features.Get() ?? - throw new InvalidOperationException("Unable to read transfer mode."); - - transferModeFeature.TransferMode = - (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0) - ? TransferMode.Binary - : TransferMode.Text; - - connection.ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); - - _logger.UsingHubProtocol(protocol.Name); - - return true; - } - } - } - } - } - catch (OperationCanceledException) - { - _logger.NegotiateCanceled(); - } - - return false; - } private async Task RunHubAsync(HubConnectionContext connection) { @@ -352,9 +262,9 @@ await SendMessageAsync(connection, CompletionMessage.WithError( private async Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage) { - while (await connection.Output.WaitToWriteAsync()) + while (await connection.Output.Writer.WaitToWriteAsync()) { - if (connection.Output.TryWrite(hubMessage)) + if (connection.Output.Writer.TryWrite(hubMessage)) { return; } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs index 1ca0415e8e..b30944c6ab 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -10,8 +10,24 @@ namespace Microsoft.AspNetCore.SignalR { public class HubOptions { + /// + /// The default keep-alive interval. This is set to exactly half of the default client timeout window, + /// to ensure a ping can arrive in time to satisfy the client timeout. + /// + public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.FromSeconds(15); + public JsonSerializerSettings JsonSerializerSettings { get; set; } = JsonHubProtocol.CreateDefaultSerializerSettings(); public SerializationContext MessagePackSerializationContext { get; set; } = MessagePackHubProtocol.CreateDefaultSerializationContext(); public TimeSpan NegotiateTimeout { get; set; } = TimeSpan.FromSeconds(5); + + /// + /// The interval at which keep-alive messages should be sent. The default interval + /// is 15 seconds. + /// + /// + /// This interval is not used by the Long Polling transport as it has inherent keep-alive + /// functionality because of the polling mechanism. + /// + public TimeSpan KeepAliveInterval { get; set; } = DefaultKeepAliveInterval; } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs index 47a36488f1..5024ed2555 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs @@ -10,65 +10,72 @@ namespace Microsoft.AspNetCore.SignalR.Core.Internal internal static class SignalRCoreLoggerExtensions { // Category: HubEndPoint - private static readonly Action _usingHubProtocol = - LoggerMessage.Define(LogLevel.Information, new EventId(0, nameof(UsingHubProtocol)), "Using HubProtocol '{protocol}'."); - - private static readonly Action _negotiateCanceled = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, nameof(NegotiateCanceled)), "Negotiate was canceled."); - private static readonly Action _errorProcessingRequest = - LoggerMessage.Define(LogLevel.Error, new EventId(2, nameof(ErrorProcessingRequest)), "Error when processing requests."); + LoggerMessage.Define(LogLevel.Error, new EventId(1, nameof(ErrorProcessingRequest)), "Error when processing requests."); private static readonly Action _errorInvokingHubMethod = - LoggerMessage.Define(LogLevel.Error, new EventId(3, nameof(ErrorInvokingHubMethod)), "Error when invoking '{hubMethod}' on hub."); + LoggerMessage.Define(LogLevel.Error, new EventId(2, nameof(ErrorInvokingHubMethod)), "Error when invoking '{hubMethod}' on hub."); private static readonly Action _receivedHubInvocation = - LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(ReceivedHubInvocation)), "Received hub invocation: {invocationMessage}."); + LoggerMessage.Define(LogLevel.Debug, new EventId(3, nameof(ReceivedHubInvocation)), "Received hub invocation: {invocationMessage}."); private static readonly Action _unsupportedMessageReceived = - LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(UnsupportedMessageReceived)), "Received unsupported message of type '{messageType}'."); + LoggerMessage.Define(LogLevel.Error, new EventId(4, nameof(UnsupportedMessageReceived)), "Received unsupported message of type '{messageType}'."); private static readonly Action _unknownHubMethod = - LoggerMessage.Define(LogLevel.Error, new EventId(6, nameof(UnknownHubMethod)), "Unknown hub method '{hubMethod}'."); + LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(UnknownHubMethod)), "Unknown hub method '{hubMethod}'."); private static readonly Action _outboundChannelClosed = - LoggerMessage.Define(LogLevel.Warning, new EventId(7, nameof(OutboundChannelClosed)), "Outbound channel was closed while trying to write hub message."); + LoggerMessage.Define(LogLevel.Warning, new EventId(6, nameof(OutboundChannelClosed)), "Outbound channel was closed while trying to write hub message."); private static readonly Action _hubMethodNotAuthorized = - LoggerMessage.Define(LogLevel.Debug, new EventId(8, nameof(HubMethodNotAuthorized)), "Failed to invoke '{hubMethod}' because user is unauthorized."); + LoggerMessage.Define(LogLevel.Debug, new EventId(7, nameof(HubMethodNotAuthorized)), "Failed to invoke '{hubMethod}' because user is unauthorized."); private static readonly Action _streamingResult = - LoggerMessage.Define(LogLevel.Trace, new EventId(9, nameof(StreamingResult)), "{invocationId}: Streaming result of type '{resultType}'."); + LoggerMessage.Define(LogLevel.Trace, new EventId(8, nameof(StreamingResult)), "{invocationId}: Streaming result of type '{resultType}'."); private static readonly Action _sendingResult = - LoggerMessage.Define(LogLevel.Trace, new EventId(10, nameof(SendingResult)), "{invocationId}: Sending result of type '{resultType}'."); + LoggerMessage.Define(LogLevel.Trace, new EventId(9, nameof(SendingResult)), "{invocationId}: Sending result of type '{resultType}'."); private static readonly Action _failedInvokingHubMethod = - LoggerMessage.Define(LogLevel.Error, new EventId(11, nameof(FailedInvokingHubMethod)), "Failed to invoke hub method '{hubMethod}'."); + LoggerMessage.Define(LogLevel.Error, new EventId(10, nameof(FailedInvokingHubMethod)), "Failed to invoke hub method '{hubMethod}'."); private static readonly Action _hubMethodBound = - LoggerMessage.Define(LogLevel.Trace, new EventId(12, nameof(HubMethodBound)), "Hub method '{hubMethod}' is bound."); + LoggerMessage.Define(LogLevel.Trace, new EventId(11, nameof(HubMethodBound)), "Hub method '{hubMethod}' is bound."); private static readonly Action _cancelStream = - LoggerMessage.Define(LogLevel.Debug, new EventId(13, nameof(CancelStream)), "Canceling stream for invocation {invocationId}."); + LoggerMessage.Define(LogLevel.Debug, new EventId(12, nameof(CancelStream)), "Canceling stream for invocation {invocationId}."); private static readonly Action _unexpectedCancel = - LoggerMessage.Define(LogLevel.Debug, new EventId(14, nameof(UnexpectedCancel)), "CancelInvocationMessage received unexpectedly."); + LoggerMessage.Define(LogLevel.Debug, new EventId(13, nameof(UnexpectedCancel)), "CancelInvocationMessage received unexpectedly."); private static readonly Action _abortFailed = - LoggerMessage.Define(LogLevel.Trace, new EventId(15, nameof(AbortFailed)), "Abort callback failed."); + LoggerMessage.Define(LogLevel.Trace, new EventId(14, nameof(AbortFailed)), "Abort callback failed."); private static readonly Action _receivedStreamHubInvocation = - LoggerMessage.Define(LogLevel.Debug, new EventId(16, nameof(ReceivedStreamHubInvocation)), "Received stream hub invocation: {invocationMessage}."); + LoggerMessage.Define(LogLevel.Debug, new EventId(15, nameof(ReceivedStreamHubInvocation)), "Received stream hub invocation: {invocationMessage}."); private static readonly Action _streamingMethodCalledWithInvoke = - LoggerMessage.Define(LogLevel.Error, new EventId(17, nameof(StreamingMethodCalledWithInvoke)), "A streaming method was invoked in the non-streaming fashion : {invocationMessage}."); + LoggerMessage.Define(LogLevel.Error, new EventId(16, nameof(StreamingMethodCalledWithInvoke)), "A streaming method was invoked in the non-streaming fashion : {invocationMessage}."); private static readonly Action _nonStreamingMethodCalledWithStream = - LoggerMessage.Define(LogLevel.Error, new EventId(18, nameof(NonStreamingMethodCalledWithStream)), "A non-streaming method was invoked in the streaming fashion : {invocationMessage}."); + LoggerMessage.Define(LogLevel.Error, new EventId(17, nameof(NonStreamingMethodCalledWithStream)), "A non-streaming method was invoked in the streaming fashion : {invocationMessage}."); private static readonly Action _invalidReturnValueFromStreamingMethod = - LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(InvalidReturnValueFromStreamingMethod)), "A streaming method returned a value that cannot be used to build enumerator {hubMethod}."); + LoggerMessage.Define(LogLevel.Error, new EventId(18, nameof(InvalidReturnValueFromStreamingMethod)), "A streaming method returned a value that cannot be used to build enumerator {hubMethod}."); + + // Category: HubConnectionContext + private static readonly Action _usingHubProtocol = + LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(UsingHubProtocol)), "Using HubProtocol '{protocol}'."); + + private static readonly Action _negotiateCanceled = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, nameof(NegotiateCanceled)), "Negotiate was canceled."); + + private static readonly Action _sentPing = + LoggerMessage.Define(LogLevel.Trace, new EventId(3, nameof(SentPing)), "Sent a ping message to the client."); + + private static readonly Action _transportBufferFull = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(TransportBufferFull)), "Unable to send Ping message to client, the transport buffer is full."); public static void UsingHubProtocol(this ILogger logger, string hubProtocol) { @@ -169,5 +176,15 @@ public static void InvalidReturnValueFromStreamingMethod(this ILogger logger, st { _invalidReturnValueFromStreamingMethod(logger, hubMethod, null); } + + public static void SentPing(this ILogger logger) + { + _sentPing(logger, null); + } + + public static void TransportBufferFull(this ILogger logger) + { + _transportBufferFull(logger, null); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs index 91975c8961..3661bf333e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Properties/AssemblyInfo.cs @@ -3,5 +3,4 @@ using System.Runtime.CompilerServices; - -[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] \ No newline at end of file +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests.Utils, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 76d1e97932..a4eff48152 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -177,7 +177,7 @@ public override Task InvokeConnectionAsync(string connectionId, string methodNam var connection = _connections[connectionId]; if (connection != null) { - return WriteAsync(connection, message); + return connection.WriteAsync(message); } return PublishAsync(_channelNamePrefix + "." + connectionId, message); @@ -370,17 +370,6 @@ public void Dispose() _ackHandler.Dispose(); } - private static async Task WriteAsync(HubConnectionContext connection, HubInvocationMessage hubMessage) - { - while (await connection.Output.WaitToWriteAsync()) - { - if (connection.Output.TryWrite(hubMessage)) - { - break; - } - } - } - private string GetInvocationId() { var invocationId = Interlocked.Increment(ref _nextInvocationId); @@ -410,7 +399,7 @@ private void SubscribeToHub() foreach (var connection in _connections) { - tasks.Add(WriteAsync(connection, message)); + tasks.Add(connection.WriteAsync(message)); } await Task.WhenAll(tasks); @@ -441,7 +430,7 @@ private void SubscribeToAllExcept() { if (!excludedIds.Contains(connection.ConnectionId)) { - tasks.Add(WriteAsync(connection, message)); + tasks.Add(connection.WriteAsync(message)); } } @@ -521,7 +510,7 @@ private Task SubscribeToConnection(HubConnectionContext connection, HashSet(data); - await WriteAsync(connection, message); + await connection.WriteAsync(message); } catch (Exception ex) { @@ -542,7 +531,7 @@ private Task SubscribeToUser(HubConnectionContext connection, HashSet re { var message = DeserializeMessage(data); - await WriteAsync(connection, message); + await connection.WriteAsync(message); } catch (Exception ex) { @@ -563,7 +552,7 @@ private Task SubscribeToGroup(string groupChannel, GroupData group) var tasks = new List(group.Connections.Count); foreach (var groupConnection in group.Connections) { - tasks.Add(WriteAsync(groupConnection, message)); + tasks.Add(groupConnection.WriteAsync(message)); } await Task.WhenAll(tasks); diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/ConnectionInherentKeepAliveFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/ConnectionInherentKeepAliveFeature.cs new file mode 100644 index 0000000000..d89ff8489c --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/ConnectionInherentKeepAliveFeature.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public class ConnectionInherentKeepAliveFeature : IConnectionInherentKeepAliveFeature + { + public TimeSpan KeepAliveInterval { get; } + + public ConnectionInherentKeepAliveFeature(TimeSpan keepAliveInterval) + { + KeepAliveInterval = keepAliveInterval; + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionHeartbeatFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionHeartbeatFeature.cs new file mode 100644 index 0000000000..016a63b869 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionHeartbeatFeature.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionHeartbeatFeature + { + void OnHeartbeat(Action action, object state); + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionInherentKeepAliveFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionInherentKeepAliveFeature.cs new file mode 100644 index 0000000000..b1f39bf567 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionInherentKeepAliveFeature.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + /// + /// Indicates if the connection transport has an "inherent keep-alive", which means that the transport will automatically + /// inform the client that it is still present. + /// + /// + /// The most common example of this feature is the Long Polling HTTP transport, which must (due to HTTP limitations) terminate + /// each poll within a particular interval and return a signal indicating "the server is still here, but there is no data yet". + /// This feature allows applications to add keep-alive functionality, but limit it only to transports that don't have some kind + /// of inherent keep-alive. + /// + public interface IConnectionInherentKeepAliveFeature + { + TimeSpan KeepAliveInterval { get; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj index 56b275ea06..e1e43ecaf5 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj @@ -3,6 +3,7 @@ Components for providing real-time bi-directional communication across the Web. netstandard2.0 + Microsoft.AspNetCore.Sockets diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 213a91b966..4ee3bb84b6 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -2,11 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging; @@ -99,7 +101,7 @@ private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate sock return; } - if (!await EnsureConnectionStateAsync(connection, context, TransportType.ServerSentEvents, supportedTransports, logScope)) + if (!await EnsureConnectionStateAsync(connection, context, TransportType.ServerSentEvents, supportedTransports, logScope, options)) { // Bad connection state. It's already set the response status code. return; @@ -125,7 +127,7 @@ private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate sock return; } - if (!await EnsureConnectionStateAsync(connection, context, TransportType.WebSockets, supportedTransports, logScope)) + if (!await EnsureConnectionStateAsync(connection, context, TransportType.WebSockets, supportedTransports, logScope, options)) { // Bad connection state. It's already set the response status code. return; @@ -149,7 +151,7 @@ private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate sock return; } - if (!await EnsureConnectionStateAsync(connection, context, TransportType.LongPolling, supportedTransports, logScope)) + if (!await EnsureConnectionStateAsync(connection, context, TransportType.LongPolling, supportedTransports, logScope, options)) { // Bad connection state. It's already set the response status code. return; @@ -334,6 +336,12 @@ private async Task DoPersistentConnection(SocketDelegate socketDelegate, private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection) { + // Verify some initialization invariants + // We want to be positive that the IConnectionInherentKeepAliveFeature is initialized before invoking the application, if the long polling transport is in use. + Debug.Assert(connection.Metadata[ConnectionMetadataNames.Transport] != null, "Transport has not been initialized yet"); + Debug.Assert((TransportType?)connection.Metadata[ConnectionMetadataNames.Transport] != TransportType.LongPolling || + connection.Features.Get() != null, "Long-polling transport is in use but IConnectionInherentKeepAliveFeature as not configured"); + // Jump onto the thread pool thread so blocking user code doesn't block the setup of the // connection and transport await AwaitableThreadPool.Yield(); @@ -435,7 +443,7 @@ private async Task ProcessSend(HttpContext context) } } - private async Task EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports, ConnectionLogScope logScope) + private async Task EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports, ConnectionLogScope logScope, HttpSocketOptions options) { if ((supportedTransports & transportType) == 0) { @@ -459,6 +467,12 @@ private async Task EnsureConnectionStateAsync(DefaultConnectionContext con return false; } + // Configure transport-specific features. + if (transportType == TransportType.LongPolling) + { + connection.Features.Set(new ConnectionInherentKeepAliveFeature(options.LongPolling.PollTimeout)); + } + // Setup the connection state from the http context connection.User = context.User; connection.SetHttpContext(context); diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs index ea9f921541..51dc9c4d5f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs @@ -1,6 +1,7 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using Microsoft.AspNetCore.Authorization; @@ -16,4 +17,4 @@ public class HttpSocketOptions public LongPollingOptions LongPolling { get; } = new LongPollingOptions(); } -} \ No newline at end of file +} diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs index a9ae71af2a..3c82a0205b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/LongPollingTransport.cs @@ -4,9 +4,10 @@ using System; using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Internal.Transports diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 58536855d8..55bb38b3e7 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -8,16 +8,20 @@ using System.IO; using System.Net.WebSockets; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets { public class ConnectionManager { + // TODO: Consider making this configurable? At least for testing? + private static readonly TimeSpan _heartbeatTickRate = TimeSpan.FromSeconds(1); + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); private Timer _timer; private readonly ILogger _logger; @@ -27,7 +31,6 @@ public class ConnectionManager public ConnectionManager(ILogger logger, IApplicationLifetime appLifetime) { _logger = logger; - appLifetime.ApplicationStarted.Register(() => Start()); appLifetime.ApplicationStopping.Register(() => CloseConnections()); } @@ -43,7 +46,7 @@ public void Start() if (_timer == null) { - _timer = new Timer(Scan, this, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); + _timer = new Timer(Scan, this, _heartbeatTickRate, _heartbeatTickRate); } } } @@ -107,7 +110,7 @@ public void Scan() try { - if (_disposed || Debugger.IsAttached) + if (_disposed) { return; } @@ -115,6 +118,11 @@ public void Scan() // Pause the timer while we're running _timer.Change(Timeout.Infinite, Timeout.Infinite); + // Time the scan so we know if it gets slower than 1sec + var timer = ValueStopwatch.StartNew(); + SocketEventSource.Log.ScanningConnections(); + _logger.ScanningConnections(); + // Scan the registered connections looking for ones that have timed out foreach (var c in _connections) { @@ -136,16 +144,27 @@ public void Scan() } // Once the decision has been made to dispose we don't check the status again - if (status == DefaultConnectionContext.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) + // But don't clean up connections while the debugger is attached. + if (!Debugger.IsAttached && status == DefaultConnectionContext.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) { _logger.ConnectionTimedOut(c.Value.ConnectionId); SocketEventSource.Log.ConnectionTimedOut(c.Value.ConnectionId); var ignore = DisposeAndRemoveAsync(c.Value); } + else + { + // Tick the heartbeat, if the connection is still active + c.Value.TickHeartbeat(); + } } + // TODO: We could use this timer to determine if the connection scanner is too slow, but we need an idea of what "too slow" is. + var elapsed = timer.GetElapsedTime(); + SocketEventSource.Log.ScannedConnections(elapsed); + _logger.ScannedConnections(elapsed); + // Resume once we finished processing all connections - _timer.Change(TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); + _timer.Change(_heartbeatTickRate, _heartbeatTickRate); } finally { diff --git a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs index 5c7c853a2d..b27b9543cd 100644 --- a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs @@ -5,8 +5,8 @@ using System.Collections.Generic; using System.Security.Claims; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Internal; @@ -18,8 +18,11 @@ public class DefaultConnectionContext : ConnectionContext, IConnectionMetadataFeature, IConnectionTransportFeature, IConnectionUserFeature, + IConnectionHeartbeatFeature, ITransferModeFeature { + private List<(Action handler, object state)> _heartbeatHandlers = new List<(Action handler, object state)>(); + // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); @@ -39,6 +42,7 @@ public DefaultConnectionContext(string id, Channel transport, Channel(this); Features.Set(this); Features.Set(this); + Features.Set(this); } public CancellationTokenSource Cancellation { get; set; } @@ -69,6 +73,19 @@ public DefaultConnectionContext(string id, Channel transport, Channel action, object state) + { + _heartbeatHandlers.Add((action, state)); + } + + public void TickHeartbeat() + { + foreach (var (handler, state) in _heartbeatHandlers) + { + handler(state); + } + } + public async Task DisposeAsync() { Task disposeTask = Task.CompletedTask; diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/SocketEventSource.cs b/src/Microsoft.AspNetCore.Sockets/Internal/SocketEventSource.cs index 66a1bb6357..1f9f22cc38 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/SocketEventSource.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/SocketEventSource.cs @@ -1,6 +1,7 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Diagnostics.Tracing; using Microsoft.Extensions.Internal; @@ -42,6 +43,15 @@ public void ConnectionStop(string connectionId, ValueStopwatch timer) } } + [NonEvent] + public void ScannedConnections(TimeSpan duration) + { + if (IsEnabled() && IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + ScannedConnections(duration.TotalMilliseconds); + } + } + [Event(eventId: 1, Level = EventLevel.Informational, Message = "Started connection '{0}'.")] public ValueStopwatch ConnectionStart(string connectionId) { @@ -74,5 +84,17 @@ public void ConnectionTimedOut(string connectionId) } } } + + [Event(eventId: 4, Level = EventLevel.Verbose, Message = "Scanning connections.")] + public void ScanningConnections() + { + if (IsEnabled() && IsEnabled(EventLevel.Verbose, EventKeywords.None)) + { + WriteEvent(4); + } + } + + [Event(eventId: 5, Level = EventLevel.Verbose, Message = "Finished scanning connections. Duration: {0:0.00}ms.")] + private void ScannedConnections(double durationInMilliseconds) => WriteEvent(5, durationInMilliseconds); } } diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/SocketLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets/Internal/SocketLoggerExtensions.cs index a167ffd048..f85281c86c 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/SocketLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/SocketLoggerExtensions.cs @@ -24,6 +24,12 @@ internal static class SocketLoggerExtensions private static readonly Action _connectionTimedOut = LoggerMessage.Define(LogLevel.Trace, new EventId(4, nameof(ConnectionTimedOut)), "{time}: ConnectionId {connectionId}: Connection timed out."); + private static readonly Action _scanningConnections = + LoggerMessage.Define(LogLevel.Trace, new EventId(5, nameof(ScanningConnections)), "{time}: Scanning connections."); + + private static readonly Action _scannedConnections = + LoggerMessage.Define(LogLevel.Trace, new EventId(6, nameof(ScannedConnections)), "{time}: Scanned connections in {duration}."); + public static void CreatedNewConnection(this ILogger logger, string connectionId) { if (logger.IsEnabled(LogLevel.Debug)) @@ -63,5 +69,21 @@ public static void ConnectionReset(this ILogger logger, string connectionId, Exc _connectionReset(logger, DateTime.Now, connectionId, exception); } } + + public static void ScanningConnections(this ILogger logger) + { + if (logger.IsEnabled(LogLevel.Trace)) + { + _scanningConnections(logger, DateTime.Now, null); + } + } + + public static void ScannedConnections(this ILogger logger, TimeSpan duration) + { + if (logger.IsEnabled(LogLevel.Trace)) + { + _scannedConnections(logger, DateTime.Now, duration, null); + } + } } } diff --git a/test/Common/TestClient.cs b/test/Common/TestClient.cs index 2acecdd66c..2afc8886e8 100644 --- a/test/Common/TestClient.cs +++ b/test/Common/TestClient.cs @@ -25,7 +25,6 @@ public class TestClient : IDisposable private CancellationTokenSource _cts; private ChannelConnection _transport; - public DefaultConnectionContext Connection { get; } public Channel Application { get; } public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; @@ -121,6 +120,9 @@ public async Task InvokeAsync(string methodName, params objec throw new NotSupportedException("Use 'StreamAsync' to call a streaming method"); case CompletionMessage completion: return completion; + case PingMessage _: + // Pings are ignored + break; default: throw new NotSupportedException("TestClient does not support receiving invocations!"); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index e8b7ef1525..012a11b799 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -9,7 +9,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj index 5d81375a97..14c53b8017 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj @@ -10,26 +10,17 @@ - - - - - - - + - - - diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 89cbb5c8e2..86d975fb89 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -9,7 +9,6 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs index 3b5b70dde8..65fc4e6556 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs @@ -4,7 +4,6 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.Extensions.Logging; using Newtonsoft.Json; using Xunit; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index 8b4d0995c6..2c029fcc41 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -8,7 +8,6 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; using Newtonsoft.Json; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index a8671a09a7..35cd3a1a6e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -8,7 +8,6 @@ using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; using Moq; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 5f14541613..55c1452223 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -7,9 +7,8 @@ using System.Net.Http; using System.Text; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index c527585014..30dee8e47a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -4,18 +4,13 @@ $(StandardTestTfms) - - - - - + - diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index 622996bef3..961700585f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -7,11 +7,10 @@ using System.Net.Http.Headers; using System.Text; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Internal; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index 5d080979f9..a4e695aee0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -6,6 +6,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj index 7cfdbb3cfa..83bc0580a8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Microsoft.AspNetCore.SignalR.Redis.Tests.csproj @@ -4,23 +4,15 @@ $(StandardTestTfms) - - - - - - - + - - diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisEndToEnd.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisEndToEnd.cs index 01c351fa36..7664d7145c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisEndToEnd.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisEndToEnd.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Client; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; @@ -37,7 +37,7 @@ public RedisEndToEndTests(RedisServerFixture serverFixture, ITestOutput _serverFixture = serverFixture; } - [ConditionalTheory] + [ConditionalTheory(Skip = "Docker tests are flaky")] [SkipIfDockerNotPresent] [MemberData(nameof(TransportTypesAndProtocolTypes))] public async Task HubConnectionCanSendAndReceiveMessages(TransportType transportType, IHubProtocol protocol) @@ -56,7 +56,7 @@ public async Task HubConnectionCanSendAndReceiveMessages(TransportType transport } } - [ConditionalTheory] + [ConditionalTheory(Skip = "Docker tests are flaky")] [SkipIfDockerNotPresent] [MemberData(nameof(TransportTypesAndProtocolTypes))] public async Task HubConnectionCanSendAndReceiveGroupMessages(TransportType transportType, IHubProtocol protocol) diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index 7dcc10a447..b92c45b291 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -3,12 +3,15 @@ using System; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Moq; using Xunit; @@ -23,24 +26,21 @@ public async Task InvokeAllAsyncWritesToAllConnectionsOutput() using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() { Factory = t => new TestConnectionMultiplexer() })); - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.OnConnectedAsync(connection2).OrTimeout(); await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output1); - AssertMessage(output2); + await AssertMessageAsync(client1); + await AssertMessageAsync(client2); } } @@ -50,16 +50,13 @@ public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() { Factory = t => new TestConnectionMultiplexer() })); - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.OnConnectedAsync(connection2).OrTimeout(); @@ -68,9 +65,11 @@ public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output1); + await AssertMessageAsync(client1); - Assert.False(output2.Reader.TryRead(out var item)); + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + Assert.Null(client2.TryRead()); } } @@ -80,16 +79,13 @@ public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() { Factory = t => new TestConnectionMultiplexer() })); - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.OnConnectedAsync(connection2).OrTimeout(); @@ -98,9 +94,11 @@ public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output1); + await AssertMessageAsync(client1); - Assert.False(output2.Reader.TryRead(out var item)); + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + Assert.Null(client2.TryRead()); } } @@ -109,19 +107,20 @@ public async Task InvokeConnectionAsyncWritesToConnectionOutput() { using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() { Factory = t => new TestConnectionMultiplexer() })); - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); + await connection.DisposeAsync().OrTimeout(); + + await AssertMessageAsync(client); } } @@ -153,19 +152,16 @@ public async Task InvokeAllAsyncWithMultipleServersWritesToAllConnectionsOutput( using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager1.OnConnectedAsync(connection1).OrTimeout(); await manager2.OnConnectedAsync(connection2).OrTimeout(); await manager1.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output1); - AssertMessage(output2); + await AssertMessageAsync(client1); + await AssertMessageAsync(client2); } } @@ -186,11 +182,8 @@ public async Task InvokeAllAsyncWithMultipleServersDoesNotWriteToDisconnectedCon using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager1.OnConnectedAsync(connection1).OrTimeout(); await manager2.OnConnectedAsync(connection2).OrTimeout(); @@ -199,9 +192,11 @@ public async Task InvokeAllAsyncWithMultipleServersDoesNotWriteToDisconnectedCon await manager2.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output1); + await AssertMessageAsync(client1); - Assert.False(output2.Reader.TryRead(out var item)); + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + Assert.Null(client2.TryRead()); } } @@ -221,15 +216,13 @@ public async Task InvokeConnectionAsyncOnServerWithoutConnectionWritesOutputToCo using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager1.OnConnectedAsync(connection).OrTimeout(); await manager2.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); + await AssertMessageAsync(client); } } @@ -249,9 +242,7 @@ public async Task InvokeGroupAsyncOnServerWithoutConnectionWritesOutputToGroupCo using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager1.OnConnectedAsync(connection).OrTimeout(); @@ -259,7 +250,7 @@ public async Task InvokeGroupAsyncOnServerWithoutConnectionWritesOutputToGroupCo await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); + await AssertMessageAsync(client); } } @@ -274,9 +265,7 @@ public async Task DisconnectConnectionRemovesConnectionFromGroup() using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -286,7 +275,8 @@ public async Task DisconnectConnectionRemovesConnectionFromGroup() await manager.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - Assert.False(output.Reader.TryRead(out var item)); + await connection.DisposeAsync().OrTimeout(); + Assert.Null(client.TryRead()); } } @@ -301,9 +291,7 @@ public async Task RemoveGroupFromLocalConnectionNotInGroupDoesNothing() using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -327,9 +315,7 @@ public async Task RemoveGroupFromConnectionOnDifferentServerNotInGroupDoesNothin using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager1.OnConnectedAsync(connection).OrTimeout(); @@ -351,9 +337,7 @@ public async Task AddGroupAsyncForConnectionOnDifferentServerWorks() using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager1.OnConnectedAsync(connection).OrTimeout(); @@ -361,7 +345,9 @@ public async Task AddGroupAsyncForConnectionOnDifferentServerWorks() await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); + await connection.DisposeAsync().OrTimeout(); + + await AssertMessageAsync(client); } } @@ -375,9 +361,7 @@ public async Task AddGroupAsyncForLocalConnectionAlreadyInGroupDoesNothing() using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -386,8 +370,11 @@ public async Task AddGroupAsyncForLocalConnectionAlreadyInGroupDoesNothing() await manager.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); - Assert.False(output.Reader.TryRead(out var item)); + await connection.DisposeAsync().OrTimeout(); + + await AssertMessageAsync(client); + await connection.DisposeAsync().OrTimeout(); + Assert.Null(client.TryRead()); } } @@ -405,9 +392,7 @@ public async Task AddGroupAsyncForConnectionOnDifferentServerAlreadyInGroupDoesN using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager1.OnConnectedAsync(connection).OrTimeout(); @@ -416,8 +401,11 @@ public async Task AddGroupAsyncForConnectionOnDifferentServerAlreadyInGroupDoesN await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); - Assert.False(output.Reader.TryRead(out var item)); + await connection.DisposeAsync().OrTimeout(); + + await AssertMessageAsync(client); + await connection.DisposeAsync().OrTimeout(); + Assert.Null(client.TryRead()); } } @@ -435,9 +423,7 @@ public async Task RemoveGroupAsyncForConnectionOnDifferentServerWorks() using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager1.OnConnectedAsync(connection).OrTimeout(); @@ -445,13 +431,14 @@ public async Task RemoveGroupAsyncForConnectionOnDifferentServerWorks() await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); + await AssertMessageAsync(client); await manager2.RemoveGroupAsync(connection.ConnectionId, "name").OrTimeout(); await manager2.InvokeGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - Assert.False(output.Reader.TryRead(out var item)); + await connection.DisposeAsync().OrTimeout(); + Assert.Null(client.TryRead()); } } @@ -469,9 +456,7 @@ public async Task InvokeConnectionAsyncForLocalConnectionDoesNotPublishToRedis() using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); - - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); // Add connection to both "servers" to see if connection receives message twice await manager1.OnConnectedAsync(connection).OrTimeout(); @@ -479,8 +464,10 @@ public async Task InvokeConnectionAsyncForLocalConnectionDoesNotPublishToRedis() await manager1.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output); - Assert.False(output.Reader.TryRead(out var item)); + await connection.DisposeAsync().OrTimeout(); + + await AssertMessageAsync(client); + Assert.Null(client.TryRead()); } } @@ -502,7 +489,7 @@ public async Task WritingToRemoteConnectionThatFailsDoesNotThrow() var writer = new Mock>(); writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); - var connection = new HubConnectionContext(new MockChannel(writer.Object), client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection, new MockChannel(writer.Object)); await manager2.OnConnectedAsync(connection).OrTimeout(); @@ -526,7 +513,7 @@ public async Task WritingToLocalConnectionThatFailsThrowsException() var writer = new Mock>(); writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); - var connection = new HubConnectionContext(new MockChannel(writer.Object), client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection, new MockChannel(writer.Object)); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -546,14 +533,12 @@ public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillRec using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output2 = Channel.CreateUnbounded(); - // Force an exception when writing to connection var writer = new Mock>(); writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); - var connection1 = new HubConnectionContext(new MockChannel(writer.Object), client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection, new MockChannel(writer.Object)); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.AddGroupAsync(connection1.ConnectionId, "group"); @@ -563,18 +548,17 @@ public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillRec await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout(); // connection1 will throw when receiving a group message, we are making sure other connections // are not affected by another connection throwing - AssertMessage(output2); + await AssertMessageAsync(client2); // Repeat to check that group can still be sent to await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout(); - AssertMessage(output2); + await AssertMessageAsync(client2); } } - private void AssertMessage(Channel channel) + private async Task AssertMessageAsync(TestClient client) { - Assert.True(channel.Reader.TryRead(out var item)); - var message = Assert.IsType(item); + var message = Assert.IsType(await client.ReadAsync()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisServerFixture.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisServerFixture.cs index 8697cef365..a125e6e648 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisServerFixture.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisServerFixture.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; @@ -61,4 +61,4 @@ public void Dispose() } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs new file mode 100644 index 0000000000..bf4afa8979 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace System.Threading.Channels +{ + public static class ChannelExtensions + { + public static async Task> ReadAllAsync(this ChannelReader channel) + { + var list = new List(); + while (await channel.WaitToReadAsync()) + { + while (channel.TryRead(out var item)) + { + list.Add(item); + } + } + + // Manifest any error from channel.Completion (which should be completed now) + await channel.Completion; + + return list; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs new file mode 100644 index 0000000000..66fc1f6fe9 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Channels; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public static class HubConnectionContextUtils + { + public static HubConnectionContext Create(DefaultConnectionContext connection, Channel replacementOutput = null) + { + var context = new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance); + if (replacementOutput != null) + { + context.Output = replacementOutput; + } + context.ProtocolReaderWriter = new HubProtocolReaderWriter(new JsonHubProtocol(), new PassThroughEncoder()); + + _ = context.StartAsync(); + + return context; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj new file mode 100644 index 0000000000..1e12b690f7 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj @@ -0,0 +1,20 @@ + + + + $(StandardTestTfms) + Microsoft.AspNetCore.SignalR.Tests + + + + + + + + + + + + + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs new file mode 100644 index 0000000000..3b271b509b --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs @@ -0,0 +1,111 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class ServerFixture : IDisposable + where TStartup : class + { + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; + private IWebHost _host; + private IApplicationLifetime _lifetime; + private readonly IDisposable _logToken; + + public string WebSocketsUrl => Url.Replace("http", "ws"); + + public string Url { get; private set; } + + public ServerFixture() + { + var testLog = AssemblyTestLog.ForAssembly(typeof(ServerFixture).Assembly); + _logToken = testLog.StartTestLog(null, $"{nameof(ServerFixture)}_{typeof(TStartup).Name}", out _loggerFactory, "ServerFixture"); + _logger = _loggerFactory.CreateLogger>(); + Url = "http://localhost:" + GetNextPort(); + + StartServer(Url); + } + + private void StartServer(string url) + { + _host = new WebHostBuilder() + .ConfigureLogging(builder => builder.AddProvider(new ForwardingLoggerProvider(_loggerFactory))) + .UseStartup(typeof(TStartup)) + .UseKestrel() + .UseUrls(url) + .UseContentRoot(Directory.GetCurrentDirectory()) + .Build(); + + var t = Task.Run(() => _host.Start()); + _logger.LogInformation("Starting test server..."); + _lifetime = _host.Services.GetRequiredService(); + if (!_lifetime.ApplicationStarted.WaitHandle.WaitOne(TimeSpan.FromSeconds(5))) + { + // t probably faulted + if (t.IsFaulted) + { + throw t.Exception.InnerException; + } + throw new TimeoutException("Timed out waiting for application to start."); + } + _logger.LogInformation("Test Server started"); + + _lifetime.ApplicationStopped.Register(() => + { + _logger.LogInformation("Test server shut down"); + _logToken.Dispose(); + }); + } + + public void Dispose() + { + _logger.LogInformation("Shutting down test server"); + _host.Dispose(); + _loggerFactory.Dispose(); + } + + private class ForwardingLoggerProvider : ILoggerProvider + { + private readonly ILoggerFactory _loggerFactory; + + public ForwardingLoggerProvider(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + + public void Dispose() + { + } + + public ILogger CreateLogger(string categoryName) + { + return _loggerFactory.CreateLogger(categoryName); + } + } + + // Copied from https://github.com/aspnet/KestrelHttpServer/blob/47f1db20e063c2da75d9d89653fad4eafe24446c/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/AddressRegistrationTests.cs#L508 + private static int GetNextPort() + { + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + // Let the OS assign the next available port. Unless we cycle through all ports + // on a test run, the OS will always increment the port number when making these calls. + // This prevents races in parallel test runs where a test is already bound to + // a given port, and a new test is able to bind to the same port due to port + // reuse being enabled by default by the OS. + socket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); + return ((IPEndPoint)socket.LocalEndPoint).Port; + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs new file mode 100644 index 0000000000..734571d404 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace System.Threading.Tasks +{ + public static class TaskExtensions + { + private const int DefaultTimeout = 5000; + + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) + { + return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); + } + + public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) + { + var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout)); + if (completed != task) + { + throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); + } + + await task; + } + + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) + { + return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); + } + + public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) + { + var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout)); + if (completed != task) + { + throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); + } + + return await task; + } + + private static string GetMessage(string memberName, string filePath, int? lineNumber) + { + if (!string.IsNullOrEmpty(memberName)) + { + return $"Operation in {memberName} timed out at {filePath}:{lineNumber}"; + } + else + { + return "Operation timed out"; + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs new file mode 100644 index 0000000000..2afc8886e8 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -0,0 +1,213 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Channels; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Internal; +using Newtonsoft.Json; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class TestClient : IDisposable + { + private static int _id; + private readonly HubProtocolReaderWriter _protocolReaderWriter; + private readonly IInvocationBinder _invocationBinder; + private CancellationTokenSource _cts; + private ChannelConnection _transport; + + public DefaultConnectionContext Connection { get; } + public Channel Application { get; } + public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; + + public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) + { + var options = new UnboundedChannelOptions { AllowSynchronousContinuations = synchronousCallbacks }; + var transportToApplication = Channel.CreateUnbounded(options); + var applicationToTransport = Channel.CreateUnbounded(options); + + Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); + _transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); + + Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application); + + var claimValue = Interlocked.Increment(ref _id).ToString(); + var claims = new List{ new Claim(ClaimTypes.Name, claimValue) }; + if (addClaimId) + { + claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue)); + } + + Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims)); + Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); + + protocol = protocol ?? new JsonHubProtocol(); + _protocolReaderWriter = new HubProtocolReaderWriter(protocol, new PassThroughEncoder()); + _invocationBinder = invocationBinder ?? new DefaultInvocationBinder(); + + _cts = new CancellationTokenSource(); + + using (var memoryStream = new MemoryStream()) + { + NegotiationProtocol.WriteMessage(new NegotiationMessage(protocol.Name), memoryStream); + Application.Writer.TryWrite(memoryStream.ToArray()); + } + } + + public async Task> StreamAsync(string methodName, params object[] args) + { + var invocationId = await SendStreamInvocationAsync(methodName, args); + + var messages = new List(); + while (true) + { + var message = await ReadAsync(); + + if (message == null) + { + throw new InvalidOperationException("Connection aborted!"); + } + + if (message is HubInvocationMessage hubInvocationMessage && !string.Equals(hubInvocationMessage.InvocationId, invocationId)) + { + throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); + } + + switch (message) + { + case StreamItemMessage _: + messages.Add(message); + break; + case CompletionMessage _: + messages.Add(message); + return messages; + default: + throw new NotSupportedException("TestClient does not support receiving invocations!"); + } + } + } + + public async Task InvokeAsync(string methodName, params object[] args) + { + var invocationId = await SendInvocationAsync(methodName, nonBlocking: false, args: args); + + while (true) + { + var message = await ReadAsync(); + + if (message == null) + { + throw new InvalidOperationException("Connection aborted!"); + } + + if (message is HubInvocationMessage hubInvocationMessage && !string.Equals(hubInvocationMessage.InvocationId, invocationId)) + { + throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); + } + + switch (message) + { + case StreamItemMessage result: + throw new NotSupportedException("Use 'StreamAsync' to call a streaming method"); + case CompletionMessage completion: + return completion; + case PingMessage _: + // Pings are ignored + break; + default: + throw new NotSupportedException("TestClient does not support receiving invocations!"); + } + } + } + + public Task SendInvocationAsync(string methodName, params object[] args) + { + return SendInvocationAsync(methodName, nonBlocking: false, args: args); + } + + public Task SendInvocationAsync(string methodName, bool nonBlocking, params object[] args) + { + var invocationId = GetInvocationId(); + return SendHubMessageAsync(new InvocationMessage(invocationId, nonBlocking, methodName, + argumentBindingException: null, arguments: args)); + } + + public Task SendStreamInvocationAsync(string methodName, params object[] args) + { + var invocationId = GetInvocationId(); + return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName, + argumentBindingException: null, arguments: args)); + } + + public async Task SendHubMessageAsync(HubMessage message) + { + var payload = _protocolReaderWriter.WriteMessage(message); + await Application.Writer.WriteAsync(payload); + return message is HubInvocationMessage hubMessage ? hubMessage.InvocationId : null; + } + + public async Task ReadAsync() + { + while (true) + { + var message = TryRead(); + + if (message == null) + { + if (!await Application.Reader.WaitToReadAsync()) + { + return null; + } + } + else + { + return message; + } + } + } + + public HubMessage TryRead() + { + if (Application.Reader.TryRead(out var buffer) && + _protocolReaderWriter.ReadMessages(buffer, _invocationBinder, out var messages)) + { + return messages[0]; + } + return null; + } + + public void Dispose() + { + _cts.Cancel(); + _transport.Dispose(); + } + + private static string GetInvocationId() + { + return Guid.NewGuid().ToString("N"); + } + + private class DefaultInvocationBinder : IInvocationBinder + { + public Type[] GetParameterTypes(string methodName) + { + // TODO: Possibly support actual client methods + return new[] { typeof(object) }; + } + + public Type GetReturnType(string invocationId) + { + return typeof(object); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestHelpers.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestHelpers.cs new file mode 100644 index 0000000000..17f037d178 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestHelpers.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public static class TestHelpers + { + public static bool IsWebSocketsSupported() + { + try + { + new System.Net.WebSockets.ClientWebSocket().Dispose(); + } + catch (PlatformNotSupportedException) + { + return false; + } + + return true; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index 7065eadd9d..7a765591cc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -1,9 +1,8 @@ using System; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Moq; using Xunit; @@ -17,26 +16,24 @@ public async Task InvokeAllAsyncWritesToAllConnectionsOutput() using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - var manager = new DefaultHubLifetimeManager(); - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.OnConnectedAsync(connection2).OrTimeout(); await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output1.Reader.TryRead(out var item)); - var message = Assert.IsType(item); + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); - Assert.True(output2.Reader.TryRead(out item)); - message = Assert.IsType(item); + message = Assert.IsType(client2.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); @@ -49,12 +46,9 @@ public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - var manager = new DefaultHubLifetimeManager(); - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.OnConnectedAsync(connection2).OrTimeout(); @@ -63,13 +57,15 @@ public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output1.Reader.TryRead(out var item)); - var message = Assert.IsType(item); + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); - Assert.False(output2.Reader.TryRead(out item)); + Assert.Null(client2.TryRead()); } } @@ -79,12 +75,9 @@ public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var output1 = Channel.CreateUnbounded(); - var output2 = Channel.CreateUnbounded(); - var manager = new DefaultHubLifetimeManager(); - var connection1 = new HubConnectionContext(output1, client1.Connection); - var connection2 = new HubConnectionContext(output2, client2.Connection); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); await manager.OnConnectedAsync(connection1).OrTimeout(); await manager.OnConnectedAsync(connection2).OrTimeout(); @@ -93,13 +86,15 @@ public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output1.Reader.TryRead(out var item)); - var message = Assert.IsType(item); + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); - Assert.False(output2.Reader.TryRead(out item)); + Assert.Null(client2.TryRead()); } } @@ -108,16 +103,16 @@ public async Task InvokeConnectionAsyncWritesToConnectionOutput() { using (var client = new TestClient()) { - var output = Channel.CreateUnbounded(); var manager = new DefaultHubLifetimeManager(); - var connection = new HubConnectionContext(output, client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - Assert.True(output.Reader.TryRead(out var item)); - var message = Assert.IsType(item); + await connection.DisposeAsync().OrTimeout(); + + var message = Assert.IsType(client.TryRead()); Assert.Equal("Hello", message.Target); Assert.Single(message.Arguments); Assert.Equal("World", (string)message.Arguments[0]); @@ -134,7 +129,7 @@ public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite() writer.Setup(o => o.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); var manager = new DefaultHubLifetimeManager(); - var connection = new HubConnectionContext(new MockChannel(writer.Object), client.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection, new MockChannel(writer.Object)); await manager.OnConnectedAsync(connection).OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs index ff127785f5..e8a93fd4f6 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs @@ -1,9 +1,8 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Net.Http; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Testing.xunit; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 5195cebef6..7af4600b24 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -9,7 +9,6 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Client; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Client.Http; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index fd237c0ec9..d63fc03e27 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -14,7 +14,6 @@ using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.DependencyInjection; @@ -1380,6 +1379,85 @@ public async Task AcceptsPingMessages() } } + [Fact] + public async Task DoesNotWritePingMessagesIfSufficientOtherMessagesAreSent() + { + var serviceProvider = CreateServiceProvider(services => + services.Configure(options => + options.KeepAliveInterval = TimeSpan.FromMilliseconds(100))); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient(false, new JsonHubProtocol())) + { + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection).OrTimeout(); + + await client.Connected.OrTimeout(); + + // Echo a bunch of stuff, waiting 10ms between each, until 500ms have elapsed + DateTime start = DateTime.UtcNow; + while ((DateTime.UtcNow - start).TotalMilliseconds <= 500.0) + { + await client.SendInvocationAsync("Echo", "foo").OrTimeout(); + await Task.Delay(10); + } + + // Shut down + client.Dispose(); + + await endPointLifetime.OrTimeout(); + + // We shouldn't have any ping messages + HubMessage message; + var counter = 0; + while ((message = await client.ReadAsync()) != null) + { + counter += 1; + Assert.IsNotType(message); + } + Assert.InRange(counter, 1, 50); + } + } + + [Fact] + public async Task WritesPingMessageIfNothingWrittenWhenKeepAliveIntervalElapses() + { + var serviceProvider = CreateServiceProvider(services => + services.Configure(options => + options.KeepAliveInterval = TimeSpan.FromMilliseconds(100))); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient(false, new JsonHubProtocol())) + { + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection).OrTimeout(); + await client.Connected.OrTimeout(); + + // Wait 500 ms, but make sure to yield some time up to unblock concurrent threads + // This is useful on AppVeyor because it's slow enough to end up with no time + // being available for the endpoint to run. + for (var i = 0; i < 50; i += 1) + { + client.Connection.TickHeartbeat(); + await Task.Yield(); + await Task.Delay(10); + } + + // Shut down + client.Dispose(); + + await endPointLifetime.OrTimeout(); + + // We should have all pings + HubMessage message; + var counter = 0; + while ((message = await client.ReadAsync().OrTimeout()) != null) + { + counter += 1; + Assert.Same(PingMessage.Instance, message); + } + Assert.InRange(counter, 1, 10); + } + } + private static void AssertHubMessage(HubMessage expected, HubMessage actual) { // We aren't testing InvocationIds here diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs index dba1123dd3..fb38adaf9f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Moq; using Newtonsoft.Json; @@ -20,7 +21,7 @@ public class DefaultHubProtocolResolverTests [MemberData(nameof(HubProtocols))] public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) { - var mockConnection = new Mock(Channel.CreateUnbounded().Writer, new Mock().Object); + var mockConnection = new Mock(new Mock().Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance); Assert.IsType( protocol.GetType(), new DefaultHubProtocolResolver(Options.Create(new HubOptions())).GetProtocol(protocol.Name, mockConnection.Object)); @@ -31,7 +32,7 @@ public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProto [InlineData("dummy")] public void DefaultHubProtocolResolverThrowsForNotSupportedProtocol(string protocolName) { - var mockConnection = new Mock(Channel.CreateUnbounded().Writer, new Mock().Object); + var mockConnection = new Mock(new Mock().Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance); var exception = Assert.Throws( () => new DefaultHubProtocolResolver(Options.Create(new HubOptions())).GetProtocol(protocolName, mockConnection.Object)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj index 7eab31138a..2fa2733b0d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -15,30 +15,20 @@ - - - - - - - + - - - - diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 1644c86ccc..2845e0a6c2 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -2,9 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Threading.Tasks; using System.Threading.Channels; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index f2cb45f02c..dc631a2593 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -4,7 +4,6 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.Extensions.Logging; using Xunit; diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index df36eb8214..d490ea096e 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -14,7 +14,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Internal; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; @@ -959,6 +959,32 @@ public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoin Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); } + [Fact] + public async Task SetsInherentKeepAliveFeatureOnFirstLongPollingRequest() + { + var manager = CreateConnectionManager(); + var connection = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.LongPolling.PollTimeout = TimeSpan.FromMilliseconds(1); // We don't care about the poll itself + + Assert.Null(connection.Features.Get()); + + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + Assert.NotNull(connection.Features.Get()); + Assert.Equal(options.LongPolling.PollTimeout, connection.Features.Get().KeepAliveInterval); + } + private class RejectHandler : TestAuthenticationHandler { protected override bool ShouldAccept => false; diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index 112314f6ac..b1824f34ad 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -1,13 +1,14 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO; using System.Text; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging; using Xunit; @@ -19,13 +20,16 @@ public class LongPollingTests [Fact] public async Task Set204StatusCodeWhenChannelComplete() { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); - var poll = new LongPollingTransport(CancellationToken.None, channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + + var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(channel.Writer.TryComplete()); + Assert.True(toTransport.Writer.TryComplete()); - await poll.ProcessRequestAsync(context, context.RequestAborted); + await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout(); Assert.Equal(204, context.Response.StatusCode); } @@ -33,10 +37,13 @@ public async Task Set204StatusCodeWhenChannelComplete() [Fact] public async Task Set200StatusCodeWhenTimeoutTokenFires() { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + var timeoutToken = new CancellationToken(true); - var poll = new LongPollingTransport(timeoutToken, channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(timeoutToken, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, context.RequestAborted)) { @@ -50,17 +57,20 @@ public async Task Set200StatusCodeWhenTimeoutTokenFires() [Fact] public async Task FrameSentAsSingleResponse() { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); - var poll = new LongPollingTransport(CancellationToken.None, channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + + var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; - await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); + await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); - Assert.True(channel.Writer.TryComplete()); + Assert.True(toTransport.Writer.TryComplete()); - await poll.ProcessRequestAsync(context, context.RequestAborted); + await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout(); Assert.Equal(200, context.Response.StatusCode); Assert.Equal("Hello World", Encoding.UTF8.GetString(ms.ToArray())); @@ -69,20 +79,22 @@ public async Task FrameSentAsSingleResponse() [Fact] public async Task MultipleFramesSentAsSingleResponse() { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - var poll = new LongPollingTransport(CancellationToken.None, channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; - await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello")); - await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes(" ")); - await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes("World")); + await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes(" ")); + await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("World")); - Assert.True(channel.Writer.TryComplete()); + Assert.True(toTransport.Writer.TryComplete()); - await poll.ProcessRequestAsync(context, context.RequestAborted); + await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout(); Assert.Equal(200, context.Response.StatusCode); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs index 26260907a1..3db299c072 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs @@ -10,7 +10,6 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting.Server.Features; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj index acf1cf916c..71808e5822 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj @@ -5,21 +5,15 @@ win7-x86 - - - - + - - - diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index b7283440b4..cf60e1b028 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -3,11 +3,10 @@ using System.IO; using System.Text; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging; using Xunit; @@ -19,11 +18,14 @@ public class ServerSentEventsTests [Fact] public async Task SSESetsContentType() { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); - var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); - Assert.True(channel.Writer.TryComplete()); + var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + Assert.True(toTransport.Writer.TryComplete()); await sse.ProcessRequestAsync(context, context.RequestAborted); @@ -34,13 +36,16 @@ public async Task SSESetsContentType() [Fact] public async Task SSETurnsResponseBufferingOff() { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + var feature = new HttpBufferingFeature(); context.Features.Set(feature); - var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); - Assert.True(channel.Writer.TryComplete()); + Assert.True(toTransport.Writer.TryComplete()); await sse.ProcessRequestAsync(context, context.RequestAborted); @@ -50,23 +55,25 @@ public async Task SSETurnsResponseBufferingOff() [Fact] public async Task SSEWritesMessages() { - var channel = Channel.CreateUnbounded(new UnboundedChannelOptions + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(new UnboundedChannelOptions { AllowSynchronousContinuations = true }); - var context = new DefaultHttpContext(); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + var ms = new MemoryStream(); context.Response.Body = ms; - var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var task = sse.ProcessRequestAsync(context, context.RequestAborted); - await channel.Writer.WriteAsync(Encoding.ASCII.GetBytes("Hello")); + await toTransport.Writer.WriteAsync(Encoding.ASCII.GetBytes("Hello")); Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray())); - channel.Writer.TryComplete(); + toTransport.Writer.TryComplete(); await task.OrTimeout(); } @@ -77,15 +84,18 @@ public async Task SSEWritesMessages() [InlineData("Hello\r\nWorld", ":\r\ndata: Hello\r\ndata: World\r\n\r\n")] public async Task SSEAddsAppropriateFraming(string message, string expected) { - var channel = Channel.CreateUnbounded(); + var toApplication = Channel.CreateUnbounded(); + var toTransport = Channel.CreateUnbounded(); var context = new DefaultHttpContext(); - var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connection = new DefaultConnectionContext("foo", toTransport, toApplication); + + var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; - await channel.Writer.WriteAsync(Encoding.UTF8.GetBytes(message)); + await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes(message)); - Assert.True(channel.Writer.TryComplete()); + Assert.True(toTransport.Writer.TryComplete()); await sse.ProcessRequestAsync(context, context.RequestAborted); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index d93d653e9b..87a30d11f7 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -5,10 +5,9 @@ using System.Net.WebSockets; using System.Text; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal; -using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging.Testing;