From a73edfa697032124a1c45de16fa01a729aee53f1 Mon Sep 17 00:00:00 2001 From: Brennan Date: Wed, 8 Mar 2023 10:27:51 -0800 Subject: [PATCH 01/25] Initial --- .../csharp/Client.Core/src/HubConnection.cs | 72 +++++-- .../HttpConnectionTests.Negotiate.cs | 4 +- .../test/UnitTests/TestTransportFactory.cs | 2 +- .../src/HttpConnection.cs | 81 ++++++- .../src/Internal/DefaultTransportFactory.cs | 8 +- .../src/Internal/ITransportFactory.cs | 2 +- .../src/Internal/LongPollingTransport.cs | 31 ++- .../src/Internal/ServerSentEventsTransport.cs | 32 ++- .../src/Internal/WebSocketsTransport.cs | 129 +++++++++-- ....AspNetCore.Http.Connections.Client.csproj | 2 + .../src/NegotiateProtocol.cs | 15 +- .../src/NegotiationResponse.cs | 2 + .../src/PublicAPI.Unshipped.txt | 2 + .../src/Internal/HttpConnectionContext.cs | 5 +- .../src/Internal/HttpConnectionDispatcher.cs | 102 +++++---- .../src/Internal/HttpConnectionManager.cs | 21 +- .../Transports/WebSocketsServerTransport.cs | 65 +++++- ...crosoft.AspNetCore.Http.Connections.csproj | 2 + .../test/HttpConnectionDispatcherTests.cs | 2 +- .../Shared/AcknowledgePipe/DuplexPipe.cs | 93 ++++++++ .../common/Shared/AcknowledgePipeV2.cs | 201 ++++++++++++++++++ .../common/Shared/ParseAckPipeReader.cs | 108 ++++++++++ src/SignalR/samples/ClientSample/HubSample.cs | 5 +- src/SignalR/samples/SignalRSamples/Program.cs | 5 +- src/SignalR/samples/SignalRSamples/Startup.cs | 7 +- .../server/Core/src/HubConnectionHandler.cs | 13 ++ .../test/DefaultTransportFactoryTests.cs | 10 +- .../server/SignalR/test/EndToEndTests.cs | 2 +- 28 files changed, 900 insertions(+), 123 deletions(-) create mode 100644 src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs create mode 100644 src/SignalR/common/Shared/AcknowledgePipeV2.cs create mode 100644 src/SignalR/common/Shared/ParseAckPipeReader.cs diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 797ecee716c8..78ce75e8a35c 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; @@ -12,6 +13,7 @@ using System.Net; using System.Reflection; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -461,10 +463,10 @@ public virtual async Task SendCoreAsync(string methodName, object?[] args, Cance } } - private async Task StartAsyncCore(CancellationToken cancellationToken) + private async Task StartAsyncCore(CancellationToken cancellationToken, bool sendHandshake = true) { _state.AssertInConnectionLock(); - SafeAssert(_state.CurrentConnectionStateUnsynchronized == null, "We already have a connection!"); + //SafeAssert(_state.CurrentConnectionStateUnsynchronized == null, "We already have a connection!"); cancellationToken.ThrowIfCancellationRequested(); @@ -472,24 +474,37 @@ private async Task StartAsyncCore(CancellationToken cancellationToken) Log.Starting(_logger); + if (_state.CurrentConnectionStateUnsynchronized is not null) + { + // public Task StartAsync(CancellationToken cancellationToken = default) + + var method = _state.CurrentConnectionStateUnsynchronized.Connection.GetType().GetMethod("StartAsync", new Type[] { typeof(CancellationToken) }); + await ((Task)method.Invoke(_state.CurrentConnectionStateUnsynchronized.Connection, new object[] { cancellationToken })).ConfigureAwait(false); + _state.CurrentConnectionStateUnsynchronized.ReceiveTask = ReceiveLoop(_state.CurrentConnectionStateUnsynchronized); + return; + } + // Start the connection var connection = await _connectionFactory.ConnectAsync(_endPoint, cancellationToken).ConfigureAwait(false); var startingConnectionState = new ConnectionState(connection, this); - // From here on, if an error occurs we need to shut down the connection because - // we still own it. - try + if (sendHandshake) { - Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); - await HandshakeAsync(startingConnectionState, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - Log.ErrorStartingConnection(_logger, ex); + // From here on, if an error occurs we need to shut down the connection because + // we still own it. + try + { + Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); + await HandshakeAsync(startingConnectionState, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + Log.ErrorStartingConnection(_logger, ex); - // Can't have any invocations to cancel, we're in the lock. - await CloseAsync(startingConnectionState.Connection).ConfigureAwait(false); - throw; + // Can't have any invocations to cancel, we're in the lock. + await CloseAsync(startingConnectionState.Connection).ConfigureAwait(false); + throw; + } } // Set this at the end to avoid setting internal state until the connection is real @@ -1330,6 +1345,17 @@ async Task StartProcessingInvocationMessages(ChannelReader in var result = await input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; + LogBytes(buffer.ToArray(), _logger); + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } + try { if (result.IsCanceled) @@ -1426,15 +1452,16 @@ private async Task HandleConnectionClose(ConnectionState connectionState) await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false); try { - SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState), - "Someone other than ReceiveLoop cleared the connection state!"); - _state.CurrentConnectionStateUnsynchronized = null; + //SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState), + // "Someone other than ReceiveLoop cleared the connection state!"); + //_state.CurrentConnectionStateUnsynchronized = null; + await ((ValueTask)connectionState.Connection.GetType().GetMethod("CloseAsync").Invoke(connectionState.Connection, Array.Empty())).ConfigureAwait(false); // Dispose the connection - await CloseAsync(connectionState.Connection).ConfigureAwait(false); + //await CloseAsync(connectionState.Connection).ConfigureAwait(false); // Cancel any outstanding invocations within the connection lock - connectionState.CancelOutstandingInvocations(connectionState.CloseException); + //connectionState.CancelOutstandingInvocations(connectionState.CloseException); if (connectionState.Stopping || _reconnectPolicy == null) { @@ -1559,10 +1586,11 @@ private async Task ReconnectAsync(Exception? closeException) await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false); try { - SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), - "Someone other than Reconnect set the connection state!"); + //SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), + // "Someone other than Reconnect set the connection state!"); - await StartAsyncCore(_state.StopCts.Token).ConfigureAwait(false); + // TODO: sendHandshake needs to be determined by something + await StartAsyncCore(_state.StopCts.Token, sendHandshake: false).ConfigureAwait(false); Log.Reconnected(_logger, previousReconnectAttempts, DateTime.UtcNow - reconnectStartTime); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 810633c73587..b7790ffc07c3 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -508,7 +508,7 @@ public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand() var transportFactory = new Mock(MockBehavior.Strict); - transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling)) + transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, true)) .Returns(new TestTransport(transferFormat: TransferFormat.Text | TransferFormat.Binary)); using (var noErrorScope = new VerifyNoErrorsScope()) @@ -557,7 +557,7 @@ public async Task StartSkipsOverTransportsThatDoNotSupportTheRequredTransferForm var transportFactory = new Mock(MockBehavior.Strict); - transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling)) + transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, true)) .Returns(new TestTransport(transferFormat: TransferFormat.Text | TransferFormat.Binary)); await WithConnectionAsync( diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs index 2923c387e755..07491c809d7b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs @@ -15,7 +15,7 @@ public TestTransportFactory(ITransport transport) _transport = transport; } - public ITransport CreateTransport(HttpTransportType availableServerTransports) + public ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck) { return _transport; } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 1d56c314f784..b33af85fa9ec 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO.Pipelines; using System.Linq; @@ -21,10 +22,16 @@ namespace Microsoft.AspNetCore.Http.Connections.Client; +internal interface IConnectionCanRetry +{ + Task StartAsync(CancellationToken cancellationToken); + ValueTask CloseAsync(); +} + /// /// Used to make a connection to an ASP.NET Core ConnectionHandler using an HTTP-based transport. /// -public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature +public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature, IConnectionCanRetry { // Not configurable on purpose, high enough that if we reach here, it's likely // a buggy server @@ -50,6 +57,7 @@ public partial class HttpConnection : ConnectionContext, IConnectionInherentKeep private readonly ILoggerFactory _loggerFactory; private readonly Uri _url; private Func>? _accessTokenProvider; + private Uri? _connectUrl; /// public override IDuplexPipe Transport @@ -308,12 +316,25 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel var transportExceptions = new List(); - if (_httpConnectionOptions.SkipNegotiation) + var skipNegotiation = _httpConnectionOptions.SkipNegotiation; + if (!string.IsNullOrEmpty(_connectionId)) + { + skipNegotiation = true; + } + + if (skipNegotiation) { if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets) { Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); - await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken).ConfigureAwait(false); + await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken, false).ConfigureAwait(false); + } + else if (skipNegotiation && !_httpConnectionOptions.SkipNegotiation) + { + Debug.Assert(_connectUrl is not null); + //Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); + HttpTransportType transport = _transport is WebSocketsTransport ? HttpTransportType.WebSockets : HttpTransportType.LongPolling; + await StartTransport(_connectUrl, transport, transferFormat, cancellationToken, false).ConfigureAwait(false); } else { @@ -351,7 +372,7 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel } // This should only need to happen once - var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); + _connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); // We're going to search for the transfer format as a string because we don't want to parse // all the transfer formats in the negotiation response, and we want to allow transfer formats @@ -399,11 +420,11 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel if (negotiationResponse == null) { negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken).ConfigureAwait(false); - connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); + _connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); } Log.StartingTransport(_logger, transportType, uri); - await StartTransport(connectUrl, transportType, transferFormat, cancellationToken).ConfigureAwait(false); + await StartTransport(_connectUrl, transportType, transferFormat, cancellationToken, negotiationResponse.UseAcking).ConfigureAwait(false); break; } } @@ -454,6 +475,7 @@ private async Task NegotiateAsync(Uri url, HttpClient httpC { uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); } + uri = Utils.AppendQueryString(uri, "useAck=true"); using (var request = new HttpRequestMessage(HttpMethod.Post, uri)) { @@ -500,10 +522,14 @@ private static Uri CreateConnectUrl(Uri url, string? connectionId) return Utils.AppendQueryString(url, $"id={connectionId}"); } - private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken) + private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken, bool useAck) { - // Construct the transport - var transport = _transportFactory.CreateTransport(transportType); + var transport = _transport; + if (transport is null) + { + // Construct the transport + transport = _transportFactory.CreateTransport(transportType, useAck); + } // Start the transport, giving it one end of the pipe try @@ -704,4 +730,41 @@ private async Task GetNegotiationResponseAsync(Uri uri, Can _logScope.ConnectionId = _connectionId; return negotiationResponse; } + + public async ValueTask CloseAsync() + { + await _connectionLock.WaitAsync().ConfigureAwait(false); + try + { + if (_started) + { + Log.DisposingHttpConnection(_logger); + + // Stop the transport, but we don't care if it throws. + // The transport should also have completed the pipe with this exception. + try + { + await _transport!.StopAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + Log.TransportThrewExceptionOnStop(_logger, ex); + } + + Log.Disposed(_logger); + } + else + { + Log.SkippingDispose(_logger); + } + } + finally + { + // We want to do these things even if the WaitForWriterToComplete/WaitForReaderToComplete fails + + _started = false; + + _connectionLock.Release(); + } + } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs index 3a99961b41bd..5b37a0c561a4 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs @@ -31,13 +31,13 @@ public DefaultTransportFactory(HttpTransportType requestedTransportType, ILogger _accessTokenProvider = accessTokenProvider; } - public ITransport CreateTransport(HttpTransportType availableServerTransports) + public ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck) { if (_websocketsSupported && (availableServerTransports & HttpTransportType.WebSockets & _requestedTransportType) == HttpTransportType.WebSockets) { try { - return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory, _accessTokenProvider, _httpClient); + return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory, _accessTokenProvider, _httpClient, useAck); } catch (PlatformNotSupportedException ex) { @@ -49,13 +49,13 @@ public ITransport CreateTransport(HttpTransportType availableServerTransports) if ((availableServerTransports & HttpTransportType.ServerSentEvents & _requestedTransportType) == HttpTransportType.ServerSentEvents) { // We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us. - return new ServerSentEventsTransport(_httpClient!, _httpConnectionOptions, _loggerFactory); + return new ServerSentEventsTransport(_httpClient!, _httpConnectionOptions, _loggerFactory, useAck); } if ((availableServerTransports & HttpTransportType.LongPolling & _requestedTransportType) == HttpTransportType.LongPolling) { // We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us. - return new LongPollingTransport(_httpClient!, _httpConnectionOptions, _loggerFactory); + return new LongPollingTransport(_httpClient!, _httpConnectionOptions, _loggerFactory, useAck); } throw new InvalidOperationException("No requested transports available on the server."); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs index cad87cd31772..e9779ba01fa8 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs @@ -5,5 +5,5 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; internal interface ITransportFactory { - ITransport CreateTransport(HttpTransportType availableServerTransports); + ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck); } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs index 8a14b77ba3e0..7168f74b866b 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs @@ -11,6 +11,8 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using PipelinesOverNetwork; +using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -21,6 +23,7 @@ internal sealed partial class LongPollingTransport : ITransport private readonly HttpConnectionOptions _httpConnectionOptions; private IDuplexPipe? _application; private IDuplexPipe? _transport; + private bool _useAck; // Volatile so that the poll loop sees the updated value set from a different thread private volatile Exception? _error; @@ -32,11 +35,12 @@ internal sealed partial class LongPollingTransport : ITransport public PipeWriter Output => _transport!.Output; - public LongPollingTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null) + public LongPollingTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null, bool useAck = false) { _httpClient = httpClient; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); _httpConnectionOptions = httpConnectionOptions ?? new(); + _useAck = useAck; } public async Task StartAsync(Uri url, TransferFormat transferFormat, CancellationToken cancellationToken = default) @@ -57,12 +61,35 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio } // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - var pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + DuplexPipePair pair; + if (_useAck) + { + pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + } + else + { + pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + } _transport = pair.Transport; _application = pair.Application; Running = ProcessAsync(url); + + static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + // Use for one side only, i.e. server + var ackWriterApp = new AckPipeWriter(output.Writer); + var ackReader = new AckPipeReader(output.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReader); + var transportToApplication = new DuplexPipe(ackReader, input.Writer); + var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } } private async Task ProcessAsync(Uri url) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs index 25cbce7ba4a2..48c45bf8647f 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs @@ -13,6 +13,8 @@ using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using PipelinesOverNetwork; +using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -26,6 +28,7 @@ internal sealed partial class ServerSentEventsTransport : ITransport private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly CancellationTokenSource _inputCts = new CancellationTokenSource(); private readonly ServerSentEventsMessageParser _parser = new ServerSentEventsMessageParser(); + private readonly bool _useAck; private IDuplexPipe? _transport; private IDuplexPipe? _application; @@ -35,10 +38,11 @@ internal sealed partial class ServerSentEventsTransport : ITransport public PipeWriter Output => _transport!.Output; - public ServerSentEventsTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null) + public ServerSentEventsTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null, bool useAck = false) { ArgumentNullThrowHelper.ThrowIfNull(httpClient); + _useAck = useAck; _httpClient = httpClient; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); _httpConnectionOptions = httpConnectionOptions ?? new(); @@ -72,8 +76,17 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio throw; } + // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - var pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + DuplexPipePair pair; + if (_useAck) + { + pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + } + else + { + pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + } _transport = pair.Transport; _application = pair.Application; @@ -84,6 +97,21 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // _application.Input.OnWriterCompleted((exception, state) => ((CancellationTokenSource)state).Cancel(), inputCts); Running = ProcessAsync(url, response); + + static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + // Use for one side only, i.e. server + var ackWriter = new AckPipeWriter(output.Writer); + var ackReader = new AckPipeReader(output.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); + var transportToApplication = new DuplexPipe(ackReader, input.Writer); + var applicationToTransport = new DuplexPipe(transportReader, ackWriter); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } } private async Task ProcessAsync(Uri url, HttpResponseMessage response) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 9c3d8184fc26..9f8137ee7e40 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -2,13 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Diagnostics; using System.IO.Pipelines; using System.Net; using System.Net.Http; +using System.Net.Sockets; using System.Net.WebSockets; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; +using System.Text; using System.Text.Encodings.Web; using System.Threading; using System.Threading.Tasks; @@ -16,6 +19,8 @@ using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using PipelinesOverNetwork; +using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -29,7 +34,9 @@ internal sealed partial class WebSocketsTransport : ITransport private volatile bool _aborted; private readonly HttpConnectionOptions _httpConnectionOptions; private readonly HttpClient? _httpClient; - private readonly CancellationTokenSource _stopCts = new CancellationTokenSource(); + private CancellationTokenSource _stopCts = default!; + private readonly bool _useAck; + private readonly Func> _accessTokenProvider; private IDuplexPipe? _transport; @@ -39,8 +46,10 @@ internal sealed partial class WebSocketsTransport : ITransport public PipeWriter Output => _transport!.Output; - public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func> accessTokenProvider, HttpClient? httpClient) + public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func> accessTokenProvider, HttpClient? httpClient, + bool useAck = false) { + _useAck = useAck; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); _httpConnectionOptions = httpConnectionOptions ?? new HttpConnectionOptions(); @@ -48,7 +57,7 @@ public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerF // We were given an updated delegate from the HttpConnection and we are updating what we have in httpOptions // options itself is copied object of user's options - _httpConnectionOptions.AccessTokenProvider = accessTokenProvider; + _accessTokenProvider = accessTokenProvider; _httpClient = httpClient; } @@ -201,14 +210,14 @@ static bool IsX509CertificateCollectionEqual(X509CertificateCollection? left, X5 } } - if (_httpConnectionOptions.AccessTokenProvider != null + if (_accessTokenProvider != null #if NET7_0_OR_GREATER && webSocket.Options.HttpVersion < HttpVersion.Version20 #endif ) { // Apply access token logic when using HTTP/1.1 because we don't use the AccessTokenHttpMessageHandler via HttpClient unless the user specifies HTTP/2.0 or higher - var accessToken = await _httpConnectionOptions.AccessTokenProvider().ConfigureAwait(false); + var accessToken = await _accessTokenProvider().ConfigureAwait(false); if (!string.IsNullOrWhiteSpace(accessToken)) { // We can't use request headers in the browser, so instead append the token as a query string in that case @@ -259,9 +268,10 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio throw new ArgumentException($"The '{transferFormat}' transfer format is not supported by this transport.", nameof(transferFormat)); } - _webSocketMessageType = transferFormat == TransferFormat.Binary - ? WebSocketMessageType.Binary - : WebSocketMessageType.Text; + //_webSocketMessageType = transferFormat == TransferFormat.Binary + // ? WebSocketMessageType.Binary + // : WebSocketMessageType.Text; + _webSocketMessageType = WebSocketMessageType.Binary; var resolvedUrl = ResolveWebSocketsUrl(url); @@ -278,15 +288,70 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio Log.StartedTransport(_logger); - // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - var pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + _stopCts = new CancellationTokenSource(); - _transport = pair.Transport; - _application = pair.Application; + if (_transport is null) + { + // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) + DuplexPipePair pair; + if (_useAck) + { + pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + } + else + { + pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + } + + _transport = pair.Transport; + _application = pair.Application; + } + else + { + // TODO: set pipe to start resend + if (_application!.Input is AckPipeReader reader) + { + // write nothing so just the ackid gets sent to server + // server will then send everything client may have missed as well as the last ackid so the client can resend + var buf = new byte[16]; + BitConverter.GetBytes(0).CopyTo(buf.AsMemory()); + BitConverter.GetBytes(((AckPipeWriter)(_transport.Output)).lastAck).CopyTo(buf.AsSpan().Slice(8)); + await _webSocket.SendAsync(new ArraySegment(buf, 0, 16), _webSocketMessageType, true, default).ConfigureAwait(false); + + // set after first send to server + reader.Resend(); + // once we've received something from the server (which will contain the ack id for the client) + // we can start the normal read/write loops, clients first send will resend everything the server missed + var memory = _application.Output.GetMemory(); + var isArray = MemoryMarshal.TryGetArray(memory, out var arraySegment); + Debug.Assert(isArray); + + // Exceptions are handled above where the send and receive tasks are being run. + var receiveResult = await _webSocket.ReceiveAsync(arraySegment, _stopCts.Token).ConfigureAwait(false); + _application.Output.Advance(receiveResult.Count); + + var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false); + } + } // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 Running = ProcessSocketAsync(_webSocket); + + static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + // Use for one side only, i.e. server + var ackWriter = new AckPipeWriter(output.Writer); + var ackReader = new AckPipeReader(output.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); + var transportToApplication = new DuplexPipe(ackReader, input.Writer); + var applicationToTransport = new DuplexPipe(transportReader, ackWriter); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } } private async Task ProcessSocketAsync(WebSocket socket) @@ -311,7 +376,7 @@ private async Task ProcessSocketAsync(WebSocket socket) // 2. Waiting for a websocket send to complete // Cancel the application so that ReadAsync yields - _application.Input.CancelPendingRead(); + //_application.Input.CancelPendingRead(); var resultTask = await Task.WhenAny(sending, Task.Delay(_closeTimeout, _stopCts.Token)).ConfigureAwait(false); @@ -335,7 +400,7 @@ private async Task ProcessSocketAsync(WebSocket socket) socket.Abort(); // Cancel any pending flush so that we can quit - _application.Output.CancelPendingFlush(); + //_application.Output.CancelPendingFlush(); } } } @@ -392,6 +457,18 @@ private async Task StartReceiving(WebSocket socket) Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); + //LogBytes(memory.Slice(0, receiveResult.Count), _logger); + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + sb.Append("received: "); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } + _application.Output.Advance(receiveResult.Count); var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false); @@ -412,13 +489,13 @@ private async Task StartReceiving(WebSocket socket) { if (!_aborted) { - _application.Output.Complete(ex); + //_application.Output.Complete(ex); } } finally { // We're done writing - _application.Output.Complete(); + //_application.Output.Complete(); Log.ReceiveStopped(_logger); } @@ -437,6 +514,18 @@ private async Task StartSending(WebSocket socket) var result = await _application.Input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; + //LogBytes(buffer.ToArray(), _logger); + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + sb.Append("sending: "); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } + // Get a frame from the application try @@ -509,7 +598,7 @@ private async Task StartSending(WebSocket socket) } } - _application.Input.Complete(); + //_application.Input.Complete(); Log.SendStopped(_logger); } @@ -547,11 +636,11 @@ public async Task StopAsync() return; } - _transport!.Output.Complete(); - _transport!.Input.Complete(); + //_transport!.Output.Complete(); + //_transport!.Input.Complete(); // Cancel any pending reads from the application, this should start the entire shutdown process - _application.Input.CancelPendingRead(); + //_application.Input.CancelPendingRead(); // Start ungraceful close timer _stopCts.CancelAfter(_closeTimeout); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj b/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj index e79a1fd7bdba..54cd016c2db4 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj @@ -11,6 +11,8 @@ + + diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs index 72b6838753a4..8ff9ba259ed1 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs @@ -34,6 +34,8 @@ public static class NegotiateProtocol private static readonly JsonEncodedText ErrorPropertyNameBytes = JsonEncodedText.Encode(ErrorPropertyName); private const string NegotiateVersionPropertyName = "negotiateVersion"; private static readonly JsonEncodedText NegotiateVersionPropertyNameBytes = JsonEncodedText.Encode(NegotiateVersionPropertyName); + private const string AckPropertyName = "useAck"; + private static readonly JsonEncodedText AckPropertyNameBytes = JsonEncodedText.Encode(AckPropertyName); // Use C#7.3's ReadOnlySpan optimization for static data https://vcsjones.com/2019/02/01/csharp-readonly-span-bytes-static/ // Used to detect ASP.NET SignalR Server connection attempt @@ -64,6 +66,11 @@ public static void WriteResponse(NegotiationResponse response, IBufferWriter content) List? availableTransports = null; string? error = null; int version = 0; + bool useAck = false; var completed = false; while (!completed && reader.CheckRead()) @@ -206,6 +214,10 @@ public static NegotiationResponse ParseResponse(ReadOnlySpan content) { throw new InvalidOperationException("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details."); } + else if (reader.ValueTextEquals(AckPropertyNameBytes.EncodedUtf8Bytes)) + { + useAck = reader.ReadAsBoolean(AckPropertyName); + } else { reader.Skip(); @@ -249,7 +261,8 @@ public static NegotiationResponse ParseResponse(ReadOnlySpan content) AccessToken = accessToken, AvailableTransports = availableTransports, Error = error, - Version = version + Version = version, + UseAcking = useAck, }; } catch (Exception ex) diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index 17c838137f44..17ba7671064d 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -44,4 +44,6 @@ public class NegotiationResponse /// An optional error during the negotiate. If this is not null the other properties on the response can be ignored. /// public string? Error { get; set; } + + public bool UseAcking { get; set; } } diff --git a/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt b/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..437066a1e801 100644 --- a/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Http.Connections.NegotiationResponse.UseAcking.get -> bool +Microsoft.AspNetCore.Http.Connections.NegotiationResponse.UseAcking.set -> void diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index a74f17d6aad7..b5d34bc6500f 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -384,6 +384,7 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl internal bool TryActivatePersistentConnection( ConnectionDelegate connectionDelegate, IHttpTransport transport, + Task currentRequestTask, HttpContext context, ILogger dispatcherLogger) { @@ -393,8 +394,10 @@ internal bool TryActivatePersistentConnection( { Status = HttpConnectionStatus.Active; + PreviousPollTask = currentRequestTask; + // Call into the end point passing the connection - ApplicationTask = ExecuteApplication(connectionDelegate); + ApplicationTask ??= ExecuteApplication(connectionDelegate); // Start the transport TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index de9fc73186df..3190585b64fe 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -155,37 +155,46 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti await DoPersistentConnection(connectionDelegate, sse, context, connection); } - else if (context.WebSockets.IsWebSocketRequest) + //else if (context.WebSockets.IsWebSocketRequest) + //{ + // // Connection can be established lazily + // var connection = await GetOrCreateConnectionAsync(context, options); + // if (connection == null) + // { + // // No such connection, GetOrCreateConnection already set the response status code + // return; + // } + + // if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.WebSockets, supportedTransports, logScope)) + // { + // // Bad connection state. It's already set the response status code. + // return; + // } + + // Log.EstablishedConnection(_logger); + + // // Allow the reads to be canceled + // connection.Cancellation = new CancellationTokenSource(); + + // var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); + + // await DoPersistentConnection(connectionDelegate, ws, context, connection); + //} + else { - // Connection can be established lazily - var connection = await GetOrCreateConnectionAsync(context, options); - if (connection == null) + // GET /{path} maps to long polling + + var transport = HttpTransportType.LongPolling; + if (context.WebSockets.IsWebSocketRequest) { - // No such connection, GetOrCreateConnection already set the response status code - return; - } - if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.WebSockets, supportedTransports, logScope)) + transport = HttpTransportType.WebSockets; + } + else { - // Bad connection state. It's already set the response status code. - return; + AddNoCacheHeaders(context.Response); } - Log.EstablishedConnection(_logger); - - // Allow the reads to be canceled - connection.Cancellation = new CancellationTokenSource(); - - var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); - - await DoPersistentConnection(connectionDelegate, ws, context, connection); - } - else - { - // GET /{path} maps to long polling - - AddNoCacheHeaders(context.Response); - // Connection must already exist var connection = await GetConnectionAsync(context); if (connection == null) @@ -194,7 +203,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti return; } - if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.LongPolling, supportedTransports, logScope)) + if (!await EnsureConnectionStateAsync(connection, context, transport, supportedTransports, logScope)) { // Bad connection state. It's already set the response status code. return; @@ -209,11 +218,24 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti // Create a new Tcs every poll to keep track of the poll finishing, so we can properly wait on previous polls var currentRequestTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - if (!connection.TryActivateLongPollingConnection( - connectionDelegate, context, options.LongPolling.PollTimeout, - currentRequestTcs.Task, _loggerFactory, _logger)) + switch (transport) { - return; + case HttpTransportType.None: + break; + case HttpTransportType.WebSockets: + var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); + connection.TryActivatePersistentConnection(connectionDelegate, ws, currentRequestTcs.Task, context, _logger); + break; + case HttpTransportType.LongPolling: + if (!connection.TryActivateLongPollingConnection( + connectionDelegate, context, options.LongPolling.PollTimeout, + currentRequestTcs.Task, _loggerFactory, _logger)) + { + return; + } + break; + default: + break; } context.Features.Get()?.DisableTimeout(); @@ -276,7 +298,7 @@ private async Task DoPersistentConnection(ConnectionDelegate connectionDelegate, HttpContext context, HttpConnectionContext connection) { - if (connection.TryActivatePersistentConnection(connectionDelegate, transport, context, _logger)) + //if (connection.TryActivatePersistentConnection(connectionDelegate, transport, context, _logger)) { context.Features.Get()?.DisableTimeout(); // Wait for any of them to end @@ -317,11 +339,18 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche Log.NegotiateProtocolVersionMismatch(_logger, 0); } + var useAck = false; + if (context.Request.Query.TryGetValue("UseAck", out var useAckValue)) + { + var useAckStringValue = useAckValue.ToString(); + bool.TryParse(useAckStringValue, out useAck); + } + // Establish the connection HttpConnectionContext? connection = null; if (error == null) { - connection = CreateConnection(options, clientProtocolVersion); + connection = CreateConnection(options, clientProtocolVersion, useAck); } // Set the Connection ID on the logging scope so that logs from now on will have the @@ -334,7 +363,7 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche try { // Get the bytes for the connection id - WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error); + WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error, useAck); Log.NegotiationRequest(_logger); @@ -349,7 +378,7 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche } private static void WriteNegotiatePayload(IBufferWriter writer, string? connectionId, string? connectionToken, HttpContext context, HttpConnectionDispatcherOptions options, - int clientProtocolVersion, string? error) + int clientProtocolVersion, string? error, bool useAck) { var response = new NegotiationResponse(); @@ -364,6 +393,7 @@ private static void WriteNegotiatePayload(IBufferWriter writer, string? co response.ConnectionId = connectionId; response.ConnectionToken = connectionToken; response.AvailableTransports = new List(); + response.UseAcking = useAck; if ((options.Transports & HttpTransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features)) { @@ -745,9 +775,9 @@ private static void CloneHttpContext(HttpContext context, HttpConnectionContext return connection; } - private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0) + private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0, bool useAck = false) { - return _manager.CreateConnection(options, clientProtocolVersion); + return _manager.CreateConnection(options, clientProtocolVersion, useAck); } private static void AddNoCacheHeaders(HttpResponse response) diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 3f1cce2edb6b..8511c4a618b8 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -11,6 +11,8 @@ using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using PipelinesOverNetwork; +using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Internal; @@ -67,7 +69,7 @@ internal HttpConnectionContext CreateConnection() /// Creates a connection without Pipes setup to allow saving allocations until Pipes are needed. /// /// - internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int negotiateVersion = 0) + internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int negotiateVersion = 0, bool useAck = false) { string connectionToken; var id = MakeNewConnectionId(); @@ -96,6 +98,21 @@ internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions _connections.TryAdd(connectionToken, (connection, startTimestamp)); return connection; + + static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + // Use for one side only, i.e. server + var ackWriterApp = new AckPipeWriter(output.Writer); + var ackReader = new AckPipeReader(output.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReader); + var transportToApplication = new DuplexPipe(ackReader, input.Writer); + var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); + + return new DuplexPipePair(transportToApplication, applicationToTransport); + } } public void RemoveConnection(string id, HttpTransportType transportType, HttpConnectionStopStatus status) @@ -159,7 +176,7 @@ public void Scan() // Once the decision has been made to dispose we don't check the status again // But don't clean up connections while the debugger is attached. - if (!Debugger.IsAttached && lastSeenTick.HasValue && (ticks - lastSeenTick.Value) > _disconnectTimeoutTicks) + if (!true && lastSeenTick.HasValue && (ticks - lastSeenTick.Value) > _disconnectTimeoutTicks) { Log.ConnectionTimedOut(_logger, connection.ConnectionId); HttpConnectionsEventSource.Log.ConnectionTimedOut(connection.ConnectionId); diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 3d473aa095b9..7b94a157e999 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -1,11 +1,14 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Diagnostics; using System.IO.Pipelines; using System.Net.WebSockets; +using System.Text; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; +using PipelinesOverNetwork; namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports; @@ -54,9 +57,24 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok public async Task ProcessSocketAsync(WebSocket socket) { - // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. - var receiving = StartReceiving(socket); - var sending = StartSending(socket); + Task receiving; + Task sending; + if (_application.Input is AckPipeReader reader) + { + _aborted = false; + // TODO: check if the pipe was used previously? + + reader.Resend(); + // wait for first read? + _ = await socket.ReceiveAsync(Memory.Empty, _connection.Cancellation?.Token ?? default); + } + // if (_application.Input.HasBeenUsedBefore) + // read first to get the ack id for resending + // set resend id on output pipe + // start send loop which will resend and tell the client the last ack id it got from the read side + // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + receiving = StartReceiving(socket); + sending = StartSending(socket); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending); @@ -148,7 +166,19 @@ private async Task StartReceiving(WebSocket socket) return; } + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + sb.Append("received: "); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } + Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); + LogBytes(memory.Slice(0, receiveResult.Count), _logger); _application.Output.Advance(receiveResult.Count); @@ -181,7 +211,7 @@ private async Task StartReceiving(WebSocket socket) finally { // We're done writing - _application.Output.Complete(); + //_application.Output.Complete(); } } @@ -211,9 +241,23 @@ private async Task StartSending(WebSocket socket) { Log.SendPayload(_logger, buffer.Length); - var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary - ? WebSocketMessageType.Binary - : WebSocketMessageType.Text); + //var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary + // ? WebSocketMessageType.Binary + // : WebSocketMessageType.Text); + var webSocketMessageType = WebSocketMessageType.Binary; + + LogBytes(buffer.ToArray(), _logger); + + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + sb.Append("sending: "); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } if (WebSocketCanSend(socket)) { @@ -266,7 +310,12 @@ private async Task StartSending(WebSocket socket) } } - _application.Input.Complete(); + if (error is not null) + { + _logger.LogError("Error in send {ex}.", error); + } + + //_application.Input.Complete(); } } diff --git a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj index e6ee74ec7735..c731acef4c54 100644 --- a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj @@ -17,6 +17,8 @@ + + diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 9d16f32f4a6b..518ca1646d2a 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1722,7 +1722,7 @@ public async Task AttemptingToPollWhileAlreadyPollingReplacesTheCurrentPoll() Assert.Equal(string.Empty, GetContentAsString(context1.Response.Body)); AssertResponseHasCacheHeaders(context1.Response); Assert.Equal(StatusCodes.Status200OK, context2.Response.StatusCode); - Assert.Equal("Hello, World", GetContentAsString(context2.Response.Body)); + Assert.Equal("Hello, World", GetContentAsString(context2.Response.Body).AsSpan(16).ToString()); AssertResponseHasCacheHeaders(context2.Response); } } diff --git a/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs b/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs new file mode 100644 index 000000000000..cf8f21f0809c --- /dev/null +++ b/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs @@ -0,0 +1,93 @@ +using System.IO.Pipelines; + +namespace PipelinesOverNetwork +{ + internal sealed class DuplexPipe : IDuplexPipe + { + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + var transportToApplication = new DuplexPipe(output.Reader, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + // This class exists to work around issues with value tuple on .NET Framework + public readonly struct DuplexPipePair + { + public IDuplexPipe Transport { get; } + public IDuplexPipe Application { get; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } + } + + internal sealed class AckDuplexPipe : IDuplexPipe + { + + public AckDuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + // wire up both sides for testing + var ackWriterApp = new AckPipeWriter(output.Writer); + var ackReaderApp = new AckPipeReader(output.Reader); + var ackWriterClient = new AckPipeWriter(input.Writer); + var ackReaderClient = new AckPipeReader(input.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReaderApp); + var applicationReader = new ParseAckPipeReader(ackReaderApp, ackWriterClient, ackReaderClient); + var transportToApplication = new DuplexPipe(applicationReader, ackWriterClient); + var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); + + // Use for one side only, i.e. server + //var ackWriter = new AckPipeWriter(output.Writer); + //var ackReader = new AckPipeReader(output.Reader); + //var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); + //var transportToApplication = new DuplexPipe(ackReader, input.Writer); + //var applicationToTransport = new DuplexPipe(transportReader, ackWriter); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + // This class exists to work around issues with value tuple on .NET Framework + public readonly struct DuplexPipePair + { + public IDuplexPipe Transport { get; } + public IDuplexPipe Application { get; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } + } +} diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AcknowledgePipeV2.cs new file mode 100644 index 000000000000..4a8346258d56 --- /dev/null +++ b/src/SignalR/common/Shared/AcknowledgePipeV2.cs @@ -0,0 +1,201 @@ +using System; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace PipelinesOverNetwork +{ + // Wrapper around a PipeReader that adds an Ack position which replaces Consumed + // This allows the underlying pipe to keep un-acked data in the pipe while still providing only new data to the reader + internal sealed class AckPipeReader : PipeReader + { + private readonly PipeReader _inner; + private SequencePosition _consumed; + private SequencePosition _ackPosition; + private long _ackDiff; + private long _ackId; + private long _totalWritten; + private bool _resend; + private object _lock = new object(); + + public AckPipeReader(PipeReader inner) + { + _inner = inner; + } + + public void Ack(long byteID) + { + lock (_lock) + { + //Debug.Assert(_ackDiff == 0); + // ignore? Is this a bad state? + if (byteID < _ackId) + { + return; + } + //Debug.Assert(byteID >= _ackId); + _ackDiff = byteID - _ackId; + //Console.WriteLine($"AckId: {byteID}"); + } + } + + public void Resend() + { + Debug.Assert(_resend == false); + _resend = true; + } + + public override void AdvanceTo(SequencePosition consumed) + { + AdvanceTo(consumed, consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + _consumed = consumed; + if (_consumed.Equals(_ackPosition)) + _consumed = default; + _inner.AdvanceTo(_ackPosition, examined); + } + + public override void CancelPendingRead() + { + _inner.CancelPendingRead(); + } + + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } + + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); + var buffer = res.Buffer; + lock (_lock) + { + if (_ackDiff != 0) + { + //if (buffer.Slice(_ackDiff).Start.GetInteger() == 0 && buffer.Slice(_consumed).Start.GetInteger() > 0) + //{ + // Debugger.Break(); + //} + //if (buffer.Slice(_consumed).Start.Equals(buffer.Slice(_ackDiff).Start)) + //{ + // _consumed = buffer.Slice(_ackDiff).Start; + //} + if (buffer.Slice(_consumed).First.Length == 0 && buffer.Slice(_ackDiff).Start.GetInteger() == 0) + { + _consumed = buffer.Slice(buffer.Length - buffer.Slice(_consumed).Length).Start; + } + //buffer = buffer.Slice(_ackDiff + 16); + buffer = buffer.Slice(_ackDiff); + _ackId += _ackDiff; + _ackDiff = 0; + _ackPosition = buffer.Start; + } + } + // Slice consumed, unless resending, then slice to ackPosition + // TODO: implement resend for reconnect + if (_resend) + { + _resend = false; + buffer = buffer.Slice(_ackPosition); + // update total written? + } + else + { + _ackPosition = buffer.Start; + // TODO: buffer.Length is 0 sometimes, figure out why and verify behavior + if (buffer.Length > 0 && !_consumed.Equals(default)) + { + buffer = buffer.Slice(_consumed); + } + _totalWritten += (uint)buffer.Length; + } + res = new(buffer, res.IsCanceled, res.IsCompleted); + return res; + } + + public override bool TryRead(out ReadResult result) + { + throw new NotImplementedException(); + } + } + + // Wrapper around a PipeWriter that adds framing to writes + internal sealed class AckPipeWriter : PipeWriter + { + private const int FrameSize = 16; + private readonly PipeWriter _inner; + internal long lastAck; + + Memory _frameHeader; + bool _shouldAdvanceFrameHeader; + private long _buffered; + + public AckPipeWriter(PipeWriter inner) + { + _inner = inner; + } + + public override void Advance(int bytes) + { + _buffered += bytes; + if (_shouldAdvanceFrameHeader) + { + bytes += FrameSize; + _shouldAdvanceFrameHeader = false; + } + _inner.Advance(bytes); + } + + public override void CancelPendingFlush() + { + _inner.CancelPendingFlush(); + } + + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } + + // X - 8 byte size of payload as uint + // Y - 8 byte number of acked bytes + // Z - payload + // [ XXXX YYYY ZZZZ ] + public override ValueTask FlushAsync(CancellationToken cancellationToken = default) + { +#if NETSTANDARD2_1_OR_GREATER + BitConverter.TryWriteBytes(_frameHeader.Span, _buffered); + BitConverter.TryWriteBytes(_frameHeader.Slice(8).Span, lastAck); +#else + BitConverter.GetBytes(_buffered).CopyTo(_frameHeader); + BitConverter.GetBytes(lastAck).CopyTo(_frameHeader.Slice(8).Span); +#endif + //Console.WriteLine($"SendingAckId: {lastAck}"); + _frameHeader = Memory.Empty; + _buffered = 0; + return _inner.FlushAsync(cancellationToken); + } + + public override Memory GetMemory(int sizeHint = 0) + { + var segment = _inner.GetMemory(Math.Max(FrameSize + 1, sizeHint)); + if (_frameHeader.IsEmpty || _buffered == 0) + { + // TODO: segment less than FrameSize + _frameHeader = segment.Slice(0, FrameSize); + segment = segment.Slice(FrameSize); + _shouldAdvanceFrameHeader = true; + } + return segment; + } + + public override Span GetSpan(int sizeHint = 0) + { + return GetMemory(sizeHint).Span; + } + } +} diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs new file mode 100644 index 000000000000..272b3f5ab7d9 --- /dev/null +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -0,0 +1,108 @@ +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace PipelinesOverNetwork +{ + // Read from "network" + // Parse framing and slice the read so the application doesn't see the framing + // Notify outbound pipe of framing details for when sending back + internal class ParseAckPipeReader : PipeReader + { + private readonly PipeReader _inner; + private readonly AckPipeWriter _ackPipeWriter; + private readonly AckPipeReader _ackPipeReader; + private long _totalBytes; + + private ReadOnlySequence _currentRead; + + public ParseAckPipeReader(PipeReader inner, AckPipeWriter ackPipeWriter, AckPipeReader ackPipeReader) + { + _inner = inner; + _ackPipeWriter = ackPipeWriter; + _ackPipeReader = ackPipeReader; + } + + public override void AdvanceTo(SequencePosition consumed) + { + var len =_currentRead.Length - _currentRead.Slice(consumed).Length; + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); + // ignore the empty length send, maybe don't return from ReadAsync instead? + _ackPipeWriter.lastAck += (len == 16) ? 0 : len; + _inner.AdvanceTo(consumed); + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + var len = _currentRead.Length - _currentRead.Slice(consumed).Length; + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); + _ackPipeWriter.lastAck += (len == 16) ? 0 : len; + // Track? + _inner.AdvanceTo(consumed, examined); + } + + public override void CancelPendingRead() + { + _inner.CancelPendingRead(); + } + + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } + + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + var res = await _inner.ReadAsync(cancellationToken); + if (res.IsCompleted || res.IsCanceled) + { + if (res.Buffer.Length >= 16) + res = new(res.Buffer.Slice(16), res.IsCanceled, res.IsCompleted); + return res; + } + + _currentRead = res.Buffer; + // TODO: handle previous payload not fully received + // TODO: handle multiple frame prefixed messages + var frame = res.Buffer.Slice(0, 16); + var len = ParseFrame(frame, _ackPipeReader); + _totalBytes += len; + // 0 len sent on reconnect and not part of acks + if (len != 0) + { + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + res.Buffer.Length}"); + //_ackPipeWriter.lastAck += res.Buffer.Length; + } + + // TODO: validation everywhere! + Debug.Assert(len < res.Buffer.Length); + + res = new(res.Buffer.Slice(16, len), res.IsCanceled, res.IsCompleted); + return res; + + static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) + { + Span buffer = stackalloc byte[16]; + frame.CopyTo(buffer); + // TODO: use these values +#if NETSTANDARD2_1_OR_GREATER + var len = BitConverter.ToInt64(buffer); + var ackId = BitConverter.ToInt64(buffer.Slice(8)); +#else + var len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); + var ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); +#endif + ackPipeReader.Ack(ackId); + return len; + } + } + + public override bool TryRead(out ReadResult result) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/SignalR/samples/ClientSample/HubSample.cs b/src/SignalR/samples/ClientSample/HubSample.cs index f77b9c3a9a8d..c1b121fa4ef6 100644 --- a/src/SignalR/samples/ClientSample/HubSample.cs +++ b/src/SignalR/samples/ClientSample/HubSample.cs @@ -34,7 +34,7 @@ public static async Task ExecuteAsync(string baseUrl) var connectionBuilder = new HubConnectionBuilder() .ConfigureLogging(logging => { - logging.AddConsole(); + //logging.AddConsole(); }); connectionBuilder.Services.Configure(options => @@ -55,6 +55,7 @@ public static async Task ExecuteAsync(string baseUrl) using var closedTokenSource = new CancellationTokenSource(); var connection = connectionBuilder.Build(); + connection.ServerTimeout = TimeSpan.FromSeconds(15); try { @@ -99,7 +100,7 @@ public static async Task ExecuteAsync(string baseUrl) try { - await connection.InvokeAsync("Send", line); + await connection.InvokeAsync("Send", "C#", line); } catch when (closedTokenSource.IsCancellationRequested) { diff --git a/src/SignalR/samples/SignalRSamples/Program.cs b/src/SignalR/samples/SignalRSamples/Program.cs index 3675e5c32b37..757e269b9e99 100644 --- a/src/SignalR/samples/SignalRSamples/Program.cs +++ b/src/SignalR/samples/SignalRSamples/Program.cs @@ -25,12 +25,13 @@ public static Task Main(string[] args) { factory.AddConfiguration(c.Configuration.GetSection("Logging")); factory.AddConsole(); - //factory.SetMinimumLevel(LogLevel.Trace); + factory.SetMinimumLevel(LogLevel.Trace); + //factory.SetMinimumLevel(LogLevel.Debug); }) .UseKestrel(options => { // Default port - options.ListenAnyIP(0); + options.ListenAnyIP(5000); // Hub bound to TCP end point //options.Listen(IPAddress.Any, 9001, builder => diff --git a/src/SignalR/samples/SignalRSamples/Startup.cs b/src/SignalR/samples/SignalRSamples/Startup.cs index 5a3d67e481c3..5c42ffdb7293 100644 --- a/src/SignalR/samples/SignalRSamples/Startup.cs +++ b/src/SignalR/samples/SignalRSamples/Startup.cs @@ -18,7 +18,12 @@ public void ConfigureServices(IServiceCollection services) { services.AddConnections(); - services.AddSignalR() + services.AddSignalR(o => + { + o.MaximumParallelInvocationsPerClient = 10; + o.ClientTimeoutInterval = TimeSpan.FromSeconds(100); + o.KeepAliveInterval = TimeSpan.FromSeconds(5); + }) .AddMessagePackProtocol(); //.AddStackExchangeRedis(); } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index ab3d0f5bbd7b..6b0b4c1cfaee 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -1,7 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Linq; +using System.Text; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; @@ -249,6 +251,17 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection) { var result = await input.ReadAsync(); var buffer = result.Buffer; + LogBytes(buffer.ToArray(), _logger); + + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } try { diff --git a/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs b/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs index 3daaaa4e4477..d553e88db86c 100644 --- a/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs +++ b/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs @@ -53,7 +53,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(HttpTran { var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, - transportFactory.CreateTransport(AllTransportTypes)); + transportFactory.CreateTransport(AllTransportTypes, true)); } [Theory] @@ -66,7 +66,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(Http var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( - () => transportFactory.CreateTransport(~requestedTransport)); + () => transportFactory.CreateTransport(~requestedTransport, true)); Assert.Equal("No requested transports available on the server.", ex.Message); } @@ -77,7 +77,7 @@ public void DefaultTransportFactoryCreatesWebSocketsTransportIfAvailable() { Assert.IsType( new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null) - .CreateTransport(AllTransportTypes)); + .CreateTransport(AllTransportTypes, true)); } [Theory] @@ -90,7 +90,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable_Win7(Htt { var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, - transportFactory.CreateTransport(AllTransportTypes)); + transportFactory.CreateTransport(AllTransportTypes, true)); } } @@ -103,7 +103,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport_Win7 var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( - () => transportFactory.CreateTransport(AllTransportTypes)); + () => transportFactory.CreateTransport(AllTransportTypes, true)); Assert.Equal("No requested transports available on the server.", ex.Message); } diff --git a/src/SignalR/server/SignalR/test/EndToEndTests.cs b/src/SignalR/server/SignalR/test/EndToEndTests.cs index a6e85b802510..bdde2c2673f1 100644 --- a/src/SignalR/server/SignalR/test/EndToEndTests.cs +++ b/src/SignalR/server/SignalR/test/EndToEndTests.cs @@ -685,7 +685,7 @@ private class TestTransportFactory : ITransportFactory { private ITransport _transport; - public ITransport CreateTransport(HttpTransportType availableServerTransports) + public ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck) { if (_transport == null) { From fd4f81775f7839dfed3176fdd7cab2b176b14346 Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 17 Mar 2023 09:26:25 -0700 Subject: [PATCH 02/25] It's working --- .../csharp/Client.Core/src/HubConnection.cs | 72 +++---- .../src/HttpConnection.cs | 77 +------- .../src/Internal/WebSocketsTransport.cs | 54 +++-- .../src/Internal/HttpConnectionContext.cs | 9 +- .../src/Internal/HttpConnectionDispatcher.cs | 42 +++- .../src/Internal/HttpConnectionManager.cs | 2 +- .../src/Internal/Transports/IHttpTransport.cs | 2 +- .../Transports/LongPollingServerTransport.cs | 5 +- .../ServerSentEventsServerTransport.cs | 4 +- .../Transports/WebSocketsServerTransport.cs | 34 +++- .../test/HttpConnectionDispatcherTests.cs | 4 +- .../test/HttpConnectionManagerTests.cs | 14 +- .../common/Shared/AcknowledgePipeV2.cs | 3 +- .../common/Shared/ParseAckPipeReader.cs | 158 +++++++-------- .../test/Internal/Protocol/AckPipeTests.cs | 186 ++++++++++++++++++ ...oft.AspNetCore.SignalR.Common.Tests.csproj | 3 + 16 files changed, 424 insertions(+), 245 deletions(-) create mode 100644 src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 78ce75e8a35c..797ecee716c8 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; @@ -13,7 +12,6 @@ using System.Net; using System.Reflection; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -463,10 +461,10 @@ public virtual async Task SendCoreAsync(string methodName, object?[] args, Cance } } - private async Task StartAsyncCore(CancellationToken cancellationToken, bool sendHandshake = true) + private async Task StartAsyncCore(CancellationToken cancellationToken) { _state.AssertInConnectionLock(); - //SafeAssert(_state.CurrentConnectionStateUnsynchronized == null, "We already have a connection!"); + SafeAssert(_state.CurrentConnectionStateUnsynchronized == null, "We already have a connection!"); cancellationToken.ThrowIfCancellationRequested(); @@ -474,37 +472,24 @@ private async Task StartAsyncCore(CancellationToken cancellationToken, bool send Log.Starting(_logger); - if (_state.CurrentConnectionStateUnsynchronized is not null) - { - // public Task StartAsync(CancellationToken cancellationToken = default) - - var method = _state.CurrentConnectionStateUnsynchronized.Connection.GetType().GetMethod("StartAsync", new Type[] { typeof(CancellationToken) }); - await ((Task)method.Invoke(_state.CurrentConnectionStateUnsynchronized.Connection, new object[] { cancellationToken })).ConfigureAwait(false); - _state.CurrentConnectionStateUnsynchronized.ReceiveTask = ReceiveLoop(_state.CurrentConnectionStateUnsynchronized); - return; - } - // Start the connection var connection = await _connectionFactory.ConnectAsync(_endPoint, cancellationToken).ConfigureAwait(false); var startingConnectionState = new ConnectionState(connection, this); - if (sendHandshake) + // From here on, if an error occurs we need to shut down the connection because + // we still own it. + try { - // From here on, if an error occurs we need to shut down the connection because - // we still own it. - try - { - Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); - await HandshakeAsync(startingConnectionState, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - Log.ErrorStartingConnection(_logger, ex); + Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); + await HandshakeAsync(startingConnectionState, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + Log.ErrorStartingConnection(_logger, ex); - // Can't have any invocations to cancel, we're in the lock. - await CloseAsync(startingConnectionState.Connection).ConfigureAwait(false); - throw; - } + // Can't have any invocations to cancel, we're in the lock. + await CloseAsync(startingConnectionState.Connection).ConfigureAwait(false); + throw; } // Set this at the end to avoid setting internal state until the connection is real @@ -1345,17 +1330,6 @@ async Task StartProcessingInvocationMessages(ChannelReader in var result = await input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; - LogBytes(buffer.ToArray(), _logger); - void LogBytes(Memory memory, ILogger logger) - { - var sb = new StringBuilder(); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug(sb.ToString()); - } - try { if (result.IsCanceled) @@ -1452,16 +1426,15 @@ private async Task HandleConnectionClose(ConnectionState connectionState) await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false); try { - //SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState), - // "Someone other than ReceiveLoop cleared the connection state!"); - //_state.CurrentConnectionStateUnsynchronized = null; + SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState), + "Someone other than ReceiveLoop cleared the connection state!"); + _state.CurrentConnectionStateUnsynchronized = null; - await ((ValueTask)connectionState.Connection.GetType().GetMethod("CloseAsync").Invoke(connectionState.Connection, Array.Empty())).ConfigureAwait(false); // Dispose the connection - //await CloseAsync(connectionState.Connection).ConfigureAwait(false); + await CloseAsync(connectionState.Connection).ConfigureAwait(false); // Cancel any outstanding invocations within the connection lock - //connectionState.CancelOutstandingInvocations(connectionState.CloseException); + connectionState.CancelOutstandingInvocations(connectionState.CloseException); if (connectionState.Stopping || _reconnectPolicy == null) { @@ -1586,11 +1559,10 @@ private async Task ReconnectAsync(Exception? closeException) await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false); try { - //SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), - // "Someone other than Reconnect set the connection state!"); + SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), + "Someone other than Reconnect set the connection state!"); - // TODO: sendHandshake needs to be determined by something - await StartAsyncCore(_state.StopCts.Token, sendHandshake: false).ConfigureAwait(false); + await StartAsyncCore(_state.StopCts.Token).ConfigureAwait(false); Log.Reconnected(_logger, previousReconnectAttempts, DateTime.UtcNow - reconnectStartTime); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index b33af85fa9ec..4434d8c02d6e 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -3,11 +3,9 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO.Pipelines; using System.Linq; -using System.Net; using System.Net.Http; using System.Net.WebSockets; using System.Threading; @@ -22,16 +20,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Client; -internal interface IConnectionCanRetry -{ - Task StartAsync(CancellationToken cancellationToken); - ValueTask CloseAsync(); -} - /// /// Used to make a connection to an ASP.NET Core ConnectionHandler using an HTTP-based transport. /// -public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature, IConnectionCanRetry +public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature { // Not configurable on purpose, high enough that if we reach here, it's likely // a buggy server @@ -57,7 +49,6 @@ public partial class HttpConnection : ConnectionContext, IConnectionInherentKeep private readonly ILoggerFactory _loggerFactory; private readonly Uri _url; private Func>? _accessTokenProvider; - private Uri? _connectUrl; /// public override IDuplexPipe Transport @@ -316,26 +307,13 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel var transportExceptions = new List(); - var skipNegotiation = _httpConnectionOptions.SkipNegotiation; - if (!string.IsNullOrEmpty(_connectionId)) - { - skipNegotiation = true; - } - - if (skipNegotiation) + if (_httpConnectionOptions.SkipNegotiation) { if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets) { Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken, false).ConfigureAwait(false); } - else if (skipNegotiation && !_httpConnectionOptions.SkipNegotiation) - { - Debug.Assert(_connectUrl is not null); - //Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); - HttpTransportType transport = _transport is WebSocketsTransport ? HttpTransportType.WebSockets : HttpTransportType.LongPolling; - await StartTransport(_connectUrl, transport, transferFormat, cancellationToken, false).ConfigureAwait(false); - } else { throw new InvalidOperationException("Negotiation can only be skipped when using the WebSocket transport directly."); @@ -372,7 +350,7 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel } // This should only need to happen once - _connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); + var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); // We're going to search for the transfer format as a string because we don't want to parse // all the transfer formats in the negotiation response, and we want to allow transfer formats @@ -420,11 +398,11 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel if (negotiationResponse == null) { negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken).ConfigureAwait(false); - _connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); + connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); } Log.StartingTransport(_logger, transportType, uri); - await StartTransport(_connectUrl, transportType, transferFormat, cancellationToken, negotiationResponse.UseAcking).ConfigureAwait(false); + await StartTransport(connectUrl, transportType, transferFormat, cancellationToken, negotiationResponse.UseAcking).ConfigureAwait(false); break; } } @@ -524,12 +502,8 @@ private static Uri CreateConnectUrl(Uri url, string? connectionId) private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken, bool useAck) { - var transport = _transport; - if (transport is null) - { - // Construct the transport - transport = _transportFactory.CreateTransport(transportType, useAck); - } + // Construct the transport + var transport = _transportFactory.CreateTransport(transportType, useAck); // Start the transport, giving it one end of the pipe try @@ -730,41 +704,4 @@ private async Task GetNegotiationResponseAsync(Uri uri, Can _logScope.ConnectionId = _connectionId; return negotiationResponse; } - - public async ValueTask CloseAsync() - { - await _connectionLock.WaitAsync().ConfigureAwait(false); - try - { - if (_started) - { - Log.DisposingHttpConnection(_logger); - - // Stop the transport, but we don't care if it throws. - // The transport should also have completed the pipe with this exception. - try - { - await _transport!.StopAsync().ConfigureAwait(false); - } - catch (Exception ex) - { - Log.TransportThrewExceptionOnStop(_logger, ex); - } - - Log.Disposed(_logger); - } - else - { - Log.SkippingDispose(_logger); - } - } - finally - { - // We want to do these things even if the WaitForWriterToComplete/WaitForReaderToComplete fails - - _started = false; - - _connectionLock.Release(); - } - } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 9f8137ee7e40..7441aa629588 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -36,9 +36,9 @@ internal sealed partial class WebSocketsTransport : ITransport private readonly HttpClient? _httpClient; private CancellationTokenSource _stopCts = default!; private readonly bool _useAck; - private readonly Func> _accessTokenProvider; private IDuplexPipe? _transport; + private bool _closed; internal Task Running { get; private set; } = Task.CompletedTask; @@ -57,7 +57,7 @@ public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerF // We were given an updated delegate from the HttpConnection and we are updating what we have in httpOptions // options itself is copied object of user's options - _accessTokenProvider = accessTokenProvider; + _httpConnectionOptions.AccessTokenProvider = accessTokenProvider; _httpClient = httpClient; } @@ -210,14 +210,14 @@ static bool IsX509CertificateCollectionEqual(X509CertificateCollection? left, X5 } } - if (_accessTokenProvider != null + if (_httpConnectionOptions.AccessTokenProvider != null #if NET7_0_OR_GREATER && webSocket.Options.HttpVersion < HttpVersion.Version20 #endif ) { // Apply access token logic when using HTTP/1.1 because we don't use the AccessTokenHttpMessageHandler via HttpClient unless the user specifies HTTP/2.0 or higher - var accessToken = await _accessTokenProvider().ConfigureAwait(false); + var accessToken = await _httpConnectionOptions.AccessTokenProvider().ConfigureAwait(false); if (!string.IsNullOrWhiteSpace(accessToken)) { // We can't use request headers in the browser, so instead append the token as a query string in that case @@ -336,7 +336,7 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 - Running = ProcessSocketAsync(_webSocket); + Running = ProcessSocketAsync(_webSocket, url); static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) { @@ -354,7 +354,7 @@ static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions } } - private async Task ProcessSocketAsync(WebSocket socket) + private async Task ProcessSocketAsync(WebSocket socket, Uri url) { Debug.Assert(_application != null); @@ -375,8 +375,11 @@ private async Task ProcessSocketAsync(WebSocket socket) // 1. Waiting for application data // 2. Waiting for a websocket send to complete - // Cancel the application so that ReadAsync yields - //_application.Input.CancelPendingRead(); + if (_closed) + { + // Cancel the application so that ReadAsync yields + _application.Input.CancelPendingRead(); + } var resultTask = await Task.WhenAny(sending, Task.Delay(_closeTimeout, _stopCts.Token)).ConfigureAwait(false); @@ -400,9 +403,19 @@ private async Task ProcessSocketAsync(WebSocket socket) socket.Abort(); // Cancel any pending flush so that we can quit - //_application.Output.CancelPendingFlush(); + if (_closed) + { + _application.Output.CancelPendingFlush(); + } } } + + Console.WriteLine("closed socket"); + + if (_useAck && !_closed) + { + await StartAsync(url, _webSocketMessageType == WebSocketMessageType.Binary ? TransferFormat.Binary : TransferFormat.Text, default).ConfigureAwait(false); + } } private async Task StartReceiving(WebSocket socket) @@ -419,6 +432,7 @@ private async Task StartReceiving(WebSocket socket) if (result.MessageType == WebSocketMessageType.Close) { + _closed = true; Log.WebSocketClosed(_logger, socket.CloseStatus); if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure) @@ -445,6 +459,7 @@ private async Task StartReceiving(WebSocket socket) // Need to check again for netstandard2.1 because a close can happen between a 0-byte read and the actual read if (receiveResult.MessageType == WebSocketMessageType.Close) { + _closed = true; Log.WebSocketClosed(_logger, socket.CloseStatus); if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure) @@ -489,13 +504,17 @@ void LogBytes(Memory memory, ILogger logger) { if (!_aborted) { - //_application.Output.Complete(ex); + _application.Output.Complete(ex); + _closed = true; } } finally { // We're done writing - //_application.Output.Complete(); + if (_closed) + { + _application.Output.Complete(); + } Log.ReceiveStopped(_logger); } @@ -547,6 +566,7 @@ void LogBytes(Memory memory, ILogger logger) } else { + socket.Dispose(); break; } } @@ -598,7 +618,10 @@ void LogBytes(Memory memory, ILogger logger) } } - //_application.Input.Complete(); + if (_closed) + { + _application.Input.Complete(); + } Log.SendStopped(_logger); } @@ -628,6 +651,7 @@ private static Uri ResolveWebSocketsUrl(Uri url) public async Task StopAsync() { + _closed = true; Log.TransportStopping(_logger); if (_application == null) @@ -636,11 +660,11 @@ public async Task StopAsync() return; } - //_transport!.Output.Complete(); - //_transport!.Input.Complete(); + _transport!.Output.Complete(); + _transport!.Input.Complete(); // Cancel any pending reads from the application, this should start the entire shutdown process - //_application.Input.CancelPendingRead(); + _application.Input.CancelPendingRead(); // Start ungraceful close timer _stopCts.CancelAfter(_closeTimeout); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index b5d34bc6500f..a93156e81233 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -113,7 +113,7 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge internal bool IsAuthenticationExpirationEnabled => _options.CloseOnAuthenticationExpiration; - public Task? TransportTask { get; set; } + public Task? TransportTask { get; set; } public Task PreviousPollTask { get; set; } = Task.CompletedTask; @@ -443,7 +443,12 @@ public bool TryActivateLongPollingConnection( // On the first poll, we flush the response immediately to mark the poll as "initialized" so future // requests can be made safely - TransportTask = nonClonedContext.Response.Body.FlushAsync(); + TransportTask = Func(); + async Task Func() + { + await nonClonedContext.Response.Body.FlushAsync(); + return false; + }; } else { diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 3190585b64fe..c129671b4dee 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -153,7 +153,10 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti // We only need to provide the Input channel since writing to the application is handled through /send. var sse = new ServerSentEventsServerTransport(connection.Application.Input, connection.ConnectionId, connection, _loggerFactory); - await DoPersistentConnection(connectionDelegate, sse, context, connection); + if (connection.TryActivatePersistentConnection(connectionDelegate, sse, Task.CompletedTask, context, _logger)) + { + await DoPersistentConnection(connectionDelegate, sse, context, connection); + } } //else if (context.WebSockets.IsWebSocketRequest) //{ @@ -195,8 +198,17 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti AddNoCacheHeaders(context.Response); } - // Connection must already exist - var connection = await GetConnectionAsync(context); + HttpConnectionContext? connection; + if (transport == HttpTransportType.WebSockets) + { + connection = await GetOrCreateConnectionAsync(context, options); + } + else + { + // Connection must already exist + connection = await GetConnectionAsync(context); + } + if (connection == null) { // No such connection, GetConnection already set the response status code @@ -266,8 +278,15 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } else { - // Only allow repoll if we aren't removing the connection. - connection.MarkInactive(); + if (transport != HttpTransportType.LongPolling) + { + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); + } + else + { + // Only allow repoll if we aren't removing the connection. + connection.MarkInactive(); + } } } else if (resultTask.IsFaulted || resultTask.IsCanceled) @@ -280,8 +299,17 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } else { - // Only allow repoll if we aren't removing the connection. - connection.MarkInactive(); + Console.WriteLine("waiting transporttask"); + if (await connection.TransportTask!) + { + Console.WriteLine("disposing"); + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); + } + else + { + // Only allow repoll if we aren't removing the connection. + connection.MarkInactive(); + } } } finally diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 8511c4a618b8..11da9a2ecc7d 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -176,7 +176,7 @@ public void Scan() // Once the decision has been made to dispose we don't check the status again // But don't clean up connections while the debugger is attached. - if (!true && lastSeenTick.HasValue && (ticks - lastSeenTick.Value) > _disconnectTimeoutTicks) + if (!Debugger.IsAttached && lastSeenTick.HasValue && (ticks - lastSeenTick.Value) > _disconnectTimeoutTicks) { Log.ConnectionTimedOut(_logger, connection.ConnectionId); HttpConnectionsEventSource.Log.ConnectionTimedOut(connection.ConnectionId); diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs index af9ab19f4ef4..557c42c550e5 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs @@ -11,5 +11,5 @@ internal interface IHttpTransport /// /// /// A that completes when the transport has finished processing - Task ProcessRequestAsync(HttpContext context, CancellationToken token); + Task ProcessRequestAsync(HttpContext context, CancellationToken token); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs index e3e360328848..34251849d8d1 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs @@ -29,7 +29,7 @@ public LongPollingServerTransport(CancellationToken timeoutToken, PipeReader app _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.LongPollingTransport"); } - public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { try { @@ -43,7 +43,7 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok Log.LongPolling204(_logger); context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status204NoContent; - return; + return false; } // We're intentionally not checking cancellation here because we need to drain messages we've got so far, @@ -109,6 +109,7 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok context.Response.StatusCode = StatusCodes.Status500InternalServerError; throw; } + return false; } private static partial class Log diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs index 72c3c816b423..bf4f9d85d97b 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs @@ -28,7 +28,7 @@ public ServerSentEventsServerTransport(PipeReader application, string connection _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.ServerSentEventsTransport"); } - public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken) { context.Response.ContentType = "text/event-stream"; context.Response.Headers.CacheControl = "no-cache,no-store"; @@ -81,6 +81,8 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken can { // Closed connection } + + return true; } private static partial class Log diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 7b94a157e999..96c1823a249e 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -6,7 +6,6 @@ using System.IO.Pipelines; using System.Net.WebSockets; using System.Text; -using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using PipelinesOverNetwork; @@ -20,6 +19,8 @@ internal sealed partial class WebSocketsServerTransport : IHttpTransport private readonly HttpConnectionContext _connection; private volatile bool _aborted; + private bool _closed; + public WebSocketsServerTransport(WebSocketOptions options, IDuplexPipe application, HttpConnectionContext connection, ILoggerFactory loggerFactory) { ArgumentNullException.ThrowIfNull(options); @@ -34,7 +35,7 @@ public WebSocketsServerTransport(WebSocketOptions options, IDuplexPipe applicati _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.WebSocketsTransport"); } - public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { Debug.Assert(context.WebSockets.IsWebSocketRequest, "Not a websocket request"); @@ -53,6 +54,8 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok Log.SocketClosed(_logger); } } + + return _closed; } public async Task ProcessSocketAsync(WebSocket socket) @@ -153,6 +156,7 @@ private async Task StartReceiving(WebSocket socket) if (result.MessageType == WebSocketMessageType.Close) { + _closed = true; return; } @@ -163,6 +167,7 @@ private async Task StartReceiving(WebSocket socket) // Need to check again for netcoreapp3.0 and later because a close can happen between a 0-byte read and the actual read if (receiveResult.MessageType == WebSocketMessageType.Close) { + _closed = true; return; } @@ -205,13 +210,17 @@ void LogBytes(Memory memory, ILogger logger) { if (!_aborted && !token.IsCancellationRequested) { + _closed = true; _application.Output.Complete(ex); } } finally { - // We're done writing - //_application.Output.Complete(); + if (_closed) + { + // We're done writing + _application.Output.Complete(); + } } } @@ -269,6 +278,12 @@ void LogBytes(Memory memory, ILogger logger) break; } } + catch (OperationCanceledException ex) when (ex.CancellationToken == _connection.SendingToken) + { + _closed = true; + // Log + break; + } catch (Exception ex) { if (!_aborted) @@ -310,12 +325,15 @@ void LogBytes(Memory memory, ILogger logger) } } - if (error is not null) + //if (error is not null) + //{ + // _logger.LogError("Error in send {ex}.", error); + //} + + if (_closed) { - _logger.LogError("Error in send {ex}.", error); + _application.Input.Complete(); } - - //_application.Input.Complete(); } } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 518ca1646d2a..d7f044089634 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1366,7 +1366,7 @@ public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived() } [Theory] - [InlineData(HttpTransportType.WebSockets)] + //[InlineData(HttpTransportType.WebSockets)] [InlineData(HttpTransportType.ServerSentEvents)] public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTransportType transportType) { @@ -1722,7 +1722,7 @@ public async Task AttemptingToPollWhileAlreadyPollingReplacesTheCurrentPoll() Assert.Equal(string.Empty, GetContentAsString(context1.Response.Body)); AssertResponseHasCacheHeaders(context1.Response); Assert.Equal(StatusCodes.Status200OK, context2.Response.StatusCode); - Assert.Equal("Hello, World", GetContentAsString(context2.Response.Body).AsSpan(16).ToString()); + Assert.Equal("Hello, World", GetContentAsString(context2.Response.Body)); AssertResponseHasCacheHeaders(context2.Response); } } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index ff067e567f4a..a08e47a0b79f 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -95,6 +95,7 @@ public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates { throw new Exception("Transport failed"); } + return false; }); } @@ -102,7 +103,7 @@ public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates { // If the transport is faulted then we want to make sure the transport task only completes after // the application completes - connection.TransportTask = Task.FromException(new Exception("Application failed")); + connection.TransportTask = Task.FromException(new Exception("Application failed")); connection.ApplicationTask = Task.Run(async () => { // Wait for the application to end @@ -113,7 +114,7 @@ public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates else { connection.ApplicationTask = Task.CompletedTask; - connection.TransportTask = Task.CompletedTask; + connection.TransportTask = Task.FromResult(true); } try @@ -271,6 +272,7 @@ public async Task CloseConnectionsEndsAllPendingConnections() { connection.Application.Input.AdvanceTo(result.Buffer.End); } + return true; }); connectionManager.CloseConnections(); @@ -286,7 +288,7 @@ public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; connection.TransportTask = tcs.Task; @@ -296,7 +298,7 @@ public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() Assert.False(firstTask.IsCompleted); Assert.False(secondTask.IsCompleted); - tcs.TrySetResult(); + tcs.TrySetResult(true); await Task.WhenAll(firstTask, secondTask).DefaultTimeout(); } @@ -309,7 +311,7 @@ public async Task DisposingConnectionMultipleGetsExceptionFromTransportOrApp() { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; connection.TransportTask = tcs.Task; @@ -336,7 +338,7 @@ public async Task DisposingConnectionMultipleGetsCancellation() { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; connection.TransportTask = tcs.Task; diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AcknowledgePipeV2.cs index 4a8346258d56..72691709c363 100644 --- a/src/SignalR/common/Shared/AcknowledgePipeV2.cs +++ b/src/SignalR/common/Shared/AcknowledgePipeV2.cs @@ -11,13 +11,14 @@ namespace PipelinesOverNetwork internal sealed class AckPipeReader : PipeReader { private readonly PipeReader _inner; + private readonly object _lock = new object(); + private SequencePosition _consumed; private SequencePosition _ackPosition; private long _ackDiff; private long _ackId; private long _totalWritten; private bool _resend; - private object _lock = new object(); public AckPipeReader(PipeReader inner) { diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index 272b3f5ab7d9..b92b51707b79 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -5,104 +5,104 @@ using System.Threading; using System.Threading.Tasks; -namespace PipelinesOverNetwork +namespace PipelinesOverNetwork; + +// Read from "network" +// Parse framing and slice the read so the application doesn't see the framing +// Notify outbound pipe of framing details for when sending back +internal class ParseAckPipeReader : PipeReader { - // Read from "network" - // Parse framing and slice the read so the application doesn't see the framing - // Notify outbound pipe of framing details for when sending back - internal class ParseAckPipeReader : PipeReader + private readonly PipeReader _inner; + private readonly AckPipeWriter _ackPipeWriter; + private readonly AckPipeReader _ackPipeReader; + private long _totalBytes; + + private ReadOnlySequence _currentRead; + + public ParseAckPipeReader(PipeReader inner, AckPipeWriter ackPipeWriter, AckPipeReader ackPipeReader) { - private readonly PipeReader _inner; - private readonly AckPipeWriter _ackPipeWriter; - private readonly AckPipeReader _ackPipeReader; - private long _totalBytes; + _inner = inner; + _ackPipeWriter = ackPipeWriter; + _ackPipeReader = ackPipeReader; + } - private ReadOnlySequence _currentRead; + public override void AdvanceTo(SequencePosition consumed) + { + var len = _currentRead.Length - _currentRead.Slice(consumed).Length; + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); + // ignore the empty length send, maybe don't return from ReadAsync instead? + _ackPipeWriter.lastAck += (len == 16) ? 0 : len; + _inner.AdvanceTo(consumed); + } - public ParseAckPipeReader(PipeReader inner, AckPipeWriter ackPipeWriter, AckPipeReader ackPipeReader) - { - _inner = inner; - _ackPipeWriter = ackPipeWriter; - _ackPipeReader = ackPipeReader; - } + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + var len = _currentRead.Length - _currentRead.Slice(consumed).Length; + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); + _ackPipeWriter.lastAck += (len == 16) ? 0 : len; + // Track? + _inner.AdvanceTo(consumed, examined); + } - public override void AdvanceTo(SequencePosition consumed) - { - var len =_currentRead.Length - _currentRead.Slice(consumed).Length; - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); - // ignore the empty length send, maybe don't return from ReadAsync instead? - _ackPipeWriter.lastAck += (len == 16) ? 0 : len; - _inner.AdvanceTo(consumed); - } + public override void CancelPendingRead() + { + _inner.CancelPendingRead(); + } - public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) - { - var len = _currentRead.Length - _currentRead.Slice(consumed).Length; - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); - _ackPipeWriter.lastAck += (len == 16) ? 0 : len; - // Track? - _inner.AdvanceTo(consumed, examined); - } + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } - public override void CancelPendingRead() + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); + _currentRead = res.Buffer; + + if (res.IsCompleted || res.IsCanceled) { - _inner.CancelPendingRead(); + if (res.Buffer.Length >= 16) + res = new(res.Buffer.Slice(16), res.IsCanceled, res.IsCompleted); + return res; } - public override void Complete(Exception? exception = null) + // TODO: handle previous payload not fully received + // TODO: handle multiple frame prefixed messages + var frame = res.Buffer.Slice(0, 16); + var len = ParseFrame(frame, _ackPipeReader); + _totalBytes += len; + // 0 len sent on reconnect and not part of acks + if (len != 0) { - _inner.Complete(exception); + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + res.Buffer.Length}"); + //_ackPipeWriter.lastAck += res.Buffer.Length; } - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + // TODO: validation everywhere! + Debug.Assert(len < res.Buffer.Length); + + res = new(res.Buffer.Slice(16, len), res.IsCanceled, res.IsCompleted); + return res; + + static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) { - var res = await _inner.ReadAsync(cancellationToken); - if (res.IsCompleted || res.IsCanceled) - { - if (res.Buffer.Length >= 16) - res = new(res.Buffer.Slice(16), res.IsCanceled, res.IsCompleted); - return res; - } - - _currentRead = res.Buffer; - // TODO: handle previous payload not fully received - // TODO: handle multiple frame prefixed messages - var frame = res.Buffer.Slice(0, 16); - var len = ParseFrame(frame, _ackPipeReader); - _totalBytes += len; - // 0 len sent on reconnect and not part of acks - if (len != 0) - { - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + res.Buffer.Length}"); - //_ackPipeWriter.lastAck += res.Buffer.Length; - } - - // TODO: validation everywhere! - Debug.Assert(len < res.Buffer.Length); - - res = new(res.Buffer.Slice(16, len), res.IsCanceled, res.IsCompleted); - return res; + Span buffer = stackalloc byte[16]; + frame.CopyTo(buffer); - static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) - { - Span buffer = stackalloc byte[16]; - frame.CopyTo(buffer); - // TODO: use these values #if NETSTANDARD2_1_OR_GREATER - var len = BitConverter.ToInt64(buffer); - var ackId = BitConverter.ToInt64(buffer.Slice(8)); + var len = BitConverter.ToInt64(buffer); + var ackId = BitConverter.ToInt64(buffer.Slice(8)); #else - var len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); - var ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); + var len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); + var ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); #endif - ackPipeReader.Ack(ackId); - return len; - } + ackPipeReader.Ack(ackId); + return len; } + } - public override bool TryRead(out ReadResult result) - { - throw new NotImplementedException(); - } + public override bool TryRead(out ReadResult result) + { + throw new NotImplementedException(); } } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs new file mode 100644 index 000000000000..a9e8a93919fe --- /dev/null +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -0,0 +1,186 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using PipelinesOverNetwork; +using static PipelinesOverNetwork.AckDuplexPipe; + +namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; + +public class AckPipeTests +{ + [Fact] + public async Task CanSendAndReceiveTransport() + { + var duplexPipe = AckDuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + + var values = new byte[] { 1, 2, 3, 4, 5 }; + var flushRes = await duplexPipe.Transport.Output.WriteAsync(values); + + Assert.False(flushRes.IsCanceled); + Assert.False(flushRes.IsCompleted); + + var readResult = await duplexPipe.Application.Input.ReadAsync(); + + Assert.False(readResult.IsCanceled); + Assert.False(readResult.IsCompleted); + Assert.Equal(values.Length, readResult.Buffer.Length); + Assert.Equal(values, readResult.Buffer.ToArray()); + } + + [Fact] + public async Task CanSendAndReceiveLargeAmount() + { + var duplexPipe = AckDuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + + var values = new byte[20000]; + Random.Shared.NextBytes(values); + var flushRes = await duplexPipe.Transport.Output.WriteAsync(values); + + Assert.False(flushRes.IsCanceled); + Assert.False(flushRes.IsCompleted); + + var readResult = await duplexPipe.Application.Input.ReadAsync(); + + Assert.False(readResult.IsCanceled); + Assert.False(readResult.IsCompleted); + Assert.Equal(values.Length, readResult.Buffer.Length); + Assert.Equal(values, readResult.Buffer.ToArray()); + } + + [Fact] + public async Task CanSendAndReceiveLargeAmount_ManyWritesSingleFlush() + { + var duplexPipe = AckDuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + + var values = new byte[20000]; + Random.Shared.NextBytes(values); + var written = 0; + while (written < values.Length) + { + var mem = duplexPipe.Transport.Output.GetMemory(); + var toWrite = Math.Min(mem.Length, values.Length - written); + values.AsSpan(written, toWrite).CopyTo(mem.Span); + duplexPipe.Transport.Output.Advance(toWrite); + written += toWrite; + } + + var flushRes = await duplexPipe.Transport.Output.FlushAsync(); + + Assert.False(flushRes.IsCanceled); + Assert.False(flushRes.IsCompleted); + + var readResult = await duplexPipe.Application.Input.ReadAsync(); + + Assert.False(readResult.IsCanceled); + Assert.False(readResult.IsCompleted); + Assert.Equal(values.Length, readResult.Buffer.Length); + Assert.Equal(values, readResult.Buffer.ToArray()); + } + + [Fact] + public async Task ReadFromTransportRemovesFraming() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[20]; + WriteFrame(buffer, buffer.Length - 16, 0); + buffer[16] = 9; + buffer[17] = 9; + buffer[18] = 9; + buffer[19] = 9; + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + var res = await duplexPipe.Transport.Input.ReadAsync(); + Assert.Equal(4, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(16).ToArray(), res.Buffer.ToArray()); + } + + [Fact] + public async Task WriteFromApplicationAddsFraming() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[20]; + Random.Shared.NextBytes(buffer); + + await duplexPipe.Transport.Output.WriteAsync(buffer); + + var res = await duplexPipe.Application.Input.ReadAsync(); + var framing = ReadFrame(res.Buffer.ToArray()); + + Assert.Equal(buffer.Length, framing.Length); + Assert.Equal(0, framing.AckId); + Assert.Equal(buffer.Length + 16, res.Buffer.Length); + Assert.Equal(buffer, res.Buffer.Slice(16).ToArray()); + } + + internal static DuplexPipePair CreateClient(PipeOptions inputOptions = default, PipeOptions outputOptions = default) + { + var input = new Pipe(inputOptions ?? new()); + var output = new Pipe(outputOptions ?? new()); + + // Use for one side only, this is client side + var ackWriter = new AckPipeWriter(output.Writer); + var ackReader = new AckPipeReader(output.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); + var transportToApplication = new DuplexPipe(ackReader, input.Writer); + var applicationToTransport = new DuplexPipe(transportReader, ackWriter); + + // Transport.Output.Write goes to Application.Input, which is read in the transport code + // Application.Output.Write goes to Transport.Input, which is read in the application code + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + internal static DuplexPipePair CreateServer(PipeOptions inputOptions = default, PipeOptions outputOptions = default) + { + var input = new Pipe(inputOptions ?? new()); + var output = new Pipe(outputOptions ?? new()); + + // Use for one side only, this is server side + var ackWriter = new AckPipeWriter(output.Writer); + var ackReader = new AckPipeReader(output.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); + var transportToApplication = new DuplexPipe(ackReader, input.Writer); + var applicationToTransport = new DuplexPipe(transportReader, ackWriter); + + return new DuplexPipePair(transportToApplication, applicationToTransport); + } + + internal static void WriteFrame(byte[] header, long payloadLength, long ackId = 0) + { + Assert.True(header.Length >= 16); + + Assert.True(BitConverter.TryWriteBytes(header, payloadLength)); + Assert.True(BitConverter.TryWriteBytes(header.AsSpan(8), ackId)); + } + + internal static (long Length, long AckId) ReadFrame(byte[] buffer) + { + var len = BitConverter.ToInt64(buffer); + var ackId = BitConverter.ToInt64(buffer.AsSpan(8)); + + return (len, ackId); + } + + internal static (long PayloadLength, long AckId) ReadFrame(ref Span header) + { + Assert.True(header.Length >= 16); + + var len = BitConverter.ToInt64(header); + var ackId = BitConverter.ToInt64(header.Slice(8)); + + return (len, ackId); + } +} diff --git a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index 1120424ee602..a6631b548a30 100644 --- a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -7,6 +7,9 @@ + + + From 9d0e34f915bc1b34c6f770f6437b1aaf08331b66 Mon Sep 17 00:00:00 2001 From: Brennan Date: Tue, 28 Mar 2023 14:17:26 -0700 Subject: [PATCH 03/25] tests and base64 --- .../HttpConnectionTests.Negotiate.cs | 6 +- .../src/Internal/WebSocketsTransport.cs | 11 +- .../Transports/WebSocketsServerTransport.cs | 8 +- .../Shared/AcknowledgePipe/DuplexPipe.cs | 6 +- .../common/Shared/AcknowledgePipeV2.cs | 334 +++++----- .../common/Shared/ParseAckPipeReader.cs | 165 +++-- .../test/Internal/Protocol/AckPipeTests.cs | 607 +++++++++++++++++- 7 files changed, 923 insertions(+), 214 deletions(-) diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index b7790ffc07c3..ba621a57e32f 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -508,7 +508,7 @@ public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand() var transportFactory = new Mock(MockBehavior.Strict); - transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, true)) + transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, false)) .Returns(new TestTransport(transferFormat: TransferFormat.Text | TransferFormat.Binary)); using (var noErrorScope = new VerifyNoErrorsScope()) @@ -523,7 +523,7 @@ await WithConnectionAsync( } [Fact] - public async Task StartSkipsOverTransportsThatDoNotSupportTheRequredTransferFormat() + public async Task StartSkipsOverTransportsThatDoNotSupportTheRequiredTransferFormat() { var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); @@ -557,7 +557,7 @@ public async Task StartSkipsOverTransportsThatDoNotSupportTheRequredTransferForm var transportFactory = new Mock(MockBehavior.Strict); - transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, true)) + transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, false)) .Returns(new TestTransport(transferFormat: TransferFormat.Text | TransferFormat.Binary)); await WithConnectionAsync( diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 7441aa629588..528a0ed2d71d 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -268,10 +268,9 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio throw new ArgumentException($"The '{transferFormat}' transfer format is not supported by this transport.", nameof(transferFormat)); } - //_webSocketMessageType = transferFormat == TransferFormat.Binary - // ? WebSocketMessageType.Binary - // : WebSocketMessageType.Text; - _webSocketMessageType = WebSocketMessageType.Binary; + _webSocketMessageType = transferFormat == TransferFormat.Binary + ? WebSocketMessageType.Binary + : WebSocketMessageType.Text; var resolvedUrl = ResolveWebSocketsUrl(url); @@ -472,7 +471,7 @@ private async Task StartReceiving(WebSocket socket) Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - //LogBytes(memory.Slice(0, receiveResult.Count), _logger); + LogBytes(memory.Slice(0, receiveResult.Count), _logger); void LogBytes(Memory memory, ILogger logger) { var sb = new StringBuilder(); @@ -533,7 +532,7 @@ private async Task StartSending(WebSocket socket) var result = await _application.Input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; - //LogBytes(buffer.ToArray(), _logger); + LogBytes(buffer.ToArray(), _logger); void LogBytes(Memory memory, ILogger logger) { var sb = new StringBuilder(); diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 96c1823a249e..91ac3c59125f 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -6,6 +6,7 @@ using System.IO.Pipelines; using System.Net.WebSockets; using System.Text; +using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using PipelinesOverNetwork; @@ -250,10 +251,9 @@ private async Task StartSending(WebSocket socket) { Log.SendPayload(_logger, buffer.Length); - //var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary - // ? WebSocketMessageType.Binary - // : WebSocketMessageType.Text); - var webSocketMessageType = WebSocketMessageType.Binary; + var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary + ? WebSocketMessageType.Binary + : WebSocketMessageType.Text); LogBytes(buffer.ToArray(), _logger); diff --git a/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs b/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs index cf8f21f0809c..b6d72a7df146 100644 --- a/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs +++ b/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs @@ -1,4 +1,4 @@ -using System.IO.Pipelines; +using System.IO.Pipelines; namespace PipelinesOverNetwork { @@ -64,8 +64,8 @@ public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, Pipe var ackReaderClient = new AckPipeReader(input.Reader); var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReaderApp); var applicationReader = new ParseAckPipeReader(ackReaderApp, ackWriterClient, ackReaderClient); - var transportToApplication = new DuplexPipe(applicationReader, ackWriterClient); - var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); + var transportToApplication = new AckDuplexPipe(applicationReader, ackWriterClient); + var applicationToTransport = new AckDuplexPipe(transportReader, ackWriterApp); // Use for one side only, i.e. server //var ackWriter = new AckPipeWriter(output.Writer); diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AcknowledgePipeV2.cs index 72691709c363..d5149ff08a1f 100644 --- a/src/SignalR/common/Shared/AcknowledgePipeV2.cs +++ b/src/SignalR/common/Shared/AcknowledgePipeV2.cs @@ -1,202 +1,242 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + using System; +using System.Buffers; +using System.Buffers.Text; using System.Diagnostics; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; -namespace PipelinesOverNetwork +#nullable enable + +namespace PipelinesOverNetwork; + +// Wrapper around a PipeReader that adds an Ack position which replaces Consumed +// This allows the underlying pipe to keep un-acked data in the pipe while still providing only new data to the reader +internal sealed class AckPipeReader : PipeReader { - // Wrapper around a PipeReader that adds an Ack position which replaces Consumed - // This allows the underlying pipe to keep un-acked data in the pipe while still providing only new data to the reader - internal sealed class AckPipeReader : PipeReader - { - private readonly PipeReader _inner; - private readonly object _lock = new object(); + private readonly PipeReader _inner; + private readonly object _lock = new object(); - private SequencePosition _consumed; - private SequencePosition _ackPosition; - private long _ackDiff; - private long _ackId; - private long _totalWritten; - private bool _resend; + private SequencePosition _consumed; + private SequencePosition _ackPosition; + private long _ackDiff; + private long _ackId; + private long _totalWritten; + private bool _resend; - public AckPipeReader(PipeReader inner) - { - _inner = inner; - } + public AckPipeReader(PipeReader inner) + { + _inner = inner; + } - public void Ack(long byteID) + // Update the ack position. This number includes the framing size. + // If byteID is larger than the total bytes sent, it'll throw InvalidOperationException. + public void Ack(long byteID) + { + lock (_lock) { - lock (_lock) + //Debug.Assert(_ackDiff == 0); + // ignore? Is this a bad state? + if (byteID < _ackId) { - //Debug.Assert(_ackDiff == 0); - // ignore? Is this a bad state? - if (byteID < _ackId) + return; + } + //Debug.Assert(byteID >= _ackId); + _ackDiff = byteID - _ackId; + //Console.WriteLine($"AckId: {byteID}"); + + if (_totalWritten < byteID) + { + Throw(); + static void Throw() { - return; + throw new InvalidOperationException("Ack ID is greater than total amount of bytes that have been sent."); } - //Debug.Assert(byteID >= _ackId); - _ackDiff = byteID - _ackId; - //Console.WriteLine($"AckId: {byteID}"); } } + } - public void Resend() + public void Resend() + { + Debug.Assert(_resend == false); + if (_totalWritten == 0) { - Debug.Assert(_resend == false); - _resend = true; + return; } + _resend = true; + } - public override void AdvanceTo(SequencePosition consumed) - { - AdvanceTo(consumed, consumed); - } + public override void AdvanceTo(SequencePosition consumed) + { + AdvanceTo(consumed, consumed); + } - public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + _consumed = consumed; + if (_consumed.Equals(_ackPosition)) { - _consumed = consumed; - if (_consumed.Equals(_ackPosition)) - _consumed = default; - _inner.AdvanceTo(_ackPosition, examined); + // Reset to default, we check this in ReadAsync to know if we should provide the current read buffer to the user + // Or slice to the consumed position + _consumed = default; } + _inner.AdvanceTo(_ackPosition, examined); + } - public override void CancelPendingRead() - { - _inner.CancelPendingRead(); - } + public override void CancelPendingRead() + { + _inner.CancelPendingRead(); + } - public override void Complete(Exception? exception = null) - { - _inner.Complete(exception); - } + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); + var buffer = res.Buffer; + lock (_lock) { - var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); - var buffer = res.Buffer; - lock (_lock) + if (_ackDiff != 0) { - if (_ackDiff != 0) + // This detects the odd scenario where _consumed points to the end of a Segment and buffer.Slice(_ackDiff) points to the beginning of the next Segment + // While they technically point to different positions, they point to the same concept of "beginning of the next buffer" + var ackSlice = buffer.Slice(_ackDiff); + if (buffer.Slice(_consumed).First.Length == 0 && ackSlice.Start.GetInteger() == 0) { - //if (buffer.Slice(_ackDiff).Start.GetInteger() == 0 && buffer.Slice(_consumed).Start.GetInteger() > 0) - //{ - // Debugger.Break(); - //} - //if (buffer.Slice(_consumed).Start.Equals(buffer.Slice(_ackDiff).Start)) - //{ - // _consumed = buffer.Slice(_ackDiff).Start; - //} - if (buffer.Slice(_consumed).First.Length == 0 && buffer.Slice(_ackDiff).Start.GetInteger() == 0) - { - _consumed = buffer.Slice(buffer.Length - buffer.Slice(_consumed).Length).Start; - } - //buffer = buffer.Slice(_ackDiff + 16); - buffer = buffer.Slice(_ackDiff); - _ackId += _ackDiff; - _ackDiff = 0; - _ackPosition = buffer.Start; + // Fix consumed to point to the beginning of the next Segment + _consumed = ackSlice.Start; + // wtf does this do though + //_consumed = buffer.Slice(buffer.Length - buffer.Slice(_consumed).Length).Start; } - } - // Slice consumed, unless resending, then slice to ackPosition - // TODO: implement resend for reconnect - if (_resend) - { - _resend = false; - buffer = buffer.Slice(_ackPosition); - // update total written? - } - else - { + + buffer = ackSlice; + _ackId += _ackDiff; + _ackDiff = 0; _ackPosition = buffer.Start; - // TODO: buffer.Length is 0 sometimes, figure out why and verify behavior - if (buffer.Length > 0 && !_consumed.Equals(default)) - { - buffer = buffer.Slice(_consumed); - } - _totalWritten += (uint)buffer.Length; } - res = new(buffer, res.IsCanceled, res.IsCompleted); - return res; } - public override bool TryRead(out ReadResult result) + // Slice consumed, unless resending, then slice to ackPosition + if (_resend) { - throw new NotImplementedException(); + _resend = false; + buffer = buffer.Slice(_ackPosition); + // update total written? } + else + { + _ackPosition = buffer.Start; + // TODO: buffer.Length is 0 sometimes, figure out why and verify behavior + if (buffer.Length > 0 && !_consumed.Equals(default)) + { + buffer = buffer.Slice(_consumed); + } + _totalWritten += (uint)buffer.Length; + } + res = new(buffer, res.IsCanceled, res.IsCompleted); + return res; } - // Wrapper around a PipeWriter that adds framing to writes - internal sealed class AckPipeWriter : PipeWriter + public override bool TryRead(out ReadResult result) { - private const int FrameSize = 16; - private readonly PipeWriter _inner; - internal long lastAck; + throw new NotImplementedException(); + } +} - Memory _frameHeader; - bool _shouldAdvanceFrameHeader; - private long _buffered; +// Wrapper around a PipeWriter that adds framing to writes +internal sealed class AckPipeWriter : PipeWriter +{ + private const int FrameSize = 24; + private readonly PipeWriter _inner; + internal long lastAck; - public AckPipeWriter(PipeWriter inner) - { - _inner = inner; - } + Memory _frameHeader; + bool _shouldAdvanceFrameHeader; + private long _buffered; - public override void Advance(int bytes) - { - _buffered += bytes; - if (_shouldAdvanceFrameHeader) - { - bytes += FrameSize; - _shouldAdvanceFrameHeader = false; - } - _inner.Advance(bytes); - } + public AckPipeWriter(PipeWriter inner) + { + _inner = inner; + } - public override void CancelPendingFlush() + public override void Advance(int bytes) + { + _buffered += bytes; + if (_shouldAdvanceFrameHeader) { - _inner.CancelPendingFlush(); + bytes += FrameSize; + _shouldAdvanceFrameHeader = false; } + _inner.Advance(bytes); + } - public override void Complete(Exception? exception = null) - { - _inner.Complete(exception); - } + public override void CancelPendingFlush() + { + _inner.CancelPendingFlush(); + } - // X - 8 byte size of payload as uint - // Y - 8 byte number of acked bytes - // Z - payload - // [ XXXX YYYY ZZZZ ] - public override ValueTask FlushAsync(CancellationToken cancellationToken = default) - { -#if NETSTANDARD2_1_OR_GREATER - BitConverter.TryWriteBytes(_frameHeader.Span, _buffered); - BitConverter.TryWriteBytes(_frameHeader.Slice(8).Span, lastAck); + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } + + // X - 8 byte size of payload as uint + // Y - 8 byte number of acked bytes + // Z - payload + // [ XXXX YYYY ZZZZ ] + public override ValueTask FlushAsync(CancellationToken cancellationToken = default) + { + Debug.Assert(_frameHeader.Length >= FrameSize); + +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + var res = BitConverter.TryWriteBytes(_frameHeader.Span, _buffered); + Debug.Assert(res); + var status = Base64.EncodeToUtf8InPlace(_frameHeader.Span, 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + res = BitConverter.TryWriteBytes(_frameHeader.Slice(12).Span, lastAck); + Debug.Assert(res); + status = Base64.EncodeToUtf8InPlace(_frameHeader.Slice(12).Span, 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); #else - BitConverter.GetBytes(_buffered).CopyTo(_frameHeader); - BitConverter.GetBytes(lastAck).CopyTo(_frameHeader.Slice(8).Span); + BitConverter.GetBytes(_buffered).CopyTo(_frameHeader); + var status = Base64.EncodeToUtf8InPlace(_frameHeader.Span, 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + BitConverter.GetBytes(lastAck).CopyTo(_frameHeader.Slice(12).Span); + status = Base64.EncodeToUtf8InPlace(_frameHeader.Slice(12).Span, 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); #endif - //Console.WriteLine($"SendingAckId: {lastAck}"); - _frameHeader = Memory.Empty; - _buffered = 0; - return _inner.FlushAsync(cancellationToken); - } - public override Memory GetMemory(int sizeHint = 0) - { - var segment = _inner.GetMemory(Math.Max(FrameSize + 1, sizeHint)); - if (_frameHeader.IsEmpty || _buffered == 0) - { - // TODO: segment less than FrameSize - _frameHeader = segment.Slice(0, FrameSize); - segment = segment.Slice(FrameSize); - _shouldAdvanceFrameHeader = true; - } - return segment; - } + _frameHeader = Memory.Empty; + _buffered = 0; + return _inner.FlushAsync(cancellationToken); + } - public override Span GetSpan(int sizeHint = 0) + public override Memory GetMemory(int sizeHint = 0) + { + var segment = _inner.GetMemory(Math.Max(FrameSize + 1, sizeHint)); + if (_frameHeader.IsEmpty || _buffered == 0) { - return GetMemory(sizeHint).Span; + Debug.Assert(segment.Length > FrameSize); + + _frameHeader = segment.Slice(0, FrameSize); + segment = segment.Slice(FrameSize); + _shouldAdvanceFrameHeader = true; } + return segment; + } + + public override Span GetSpan(int sizeHint = 0) + { + return GetMemory(sizeHint).Span; } } diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index b92b51707b79..737a283322c2 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -1,21 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + using System; using System.Buffers; +using System.Buffers.Text; using System.Diagnostics; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; +#nullable enable + namespace PipelinesOverNetwork; // Read from "network" // Parse framing and slice the read so the application doesn't see the framing // Notify outbound pipe of framing details for when sending back +// Notify application pipe of ack id provided by other side of the network internal class ParseAckPipeReader : PipeReader { + private const int FrameSize = 24; private readonly PipeReader _inner; private readonly AckPipeWriter _ackPipeWriter; private readonly AckPipeReader _ackPipeReader; private long _totalBytes; + private long _remaining; private ReadOnlySequence _currentRead; @@ -28,22 +37,27 @@ public ParseAckPipeReader(PipeReader inner, AckPipeWriter ackPipeWriter, AckPipe public override void AdvanceTo(SequencePosition consumed) { - var len = _currentRead.Length - _currentRead.Slice(consumed).Length; - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); - // ignore the empty length send, maybe don't return from ReadAsync instead? - _ackPipeWriter.lastAck += (len == 16) ? 0 : len; + CommonAdvance(ref consumed); _inner.AdvanceTo(consumed); } public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) { - var len = _currentRead.Length - _currentRead.Slice(consumed).Length; - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + len}"); - _ackPipeWriter.lastAck += (len == 16) ? 0 : len; - // Track? + CommonAdvance(ref consumed); _inner.AdvanceTo(consumed, examined); } + private void CommonAdvance(ref SequencePosition consumed) + { + // Get the number of bytes consumed to update our internal state + var len = _currentRead.Length; + // This is used by ReadAsync to help update the ack id + _currentRead = _currentRead.Slice(consumed); + len -= _currentRead.Length; + + _remaining -= len; + } + public override void CancelPendingRead() { _inner.CancelPendingRead(); @@ -57,45 +71,122 @@ public override void Complete(Exception? exception = null) public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) { var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); - _currentRead = res.Buffer; - - if (res.IsCompleted || res.IsCanceled) + try { - if (res.Buffer.Length >= 16) - res = new(res.Buffer.Slice(16), res.IsCanceled, res.IsCompleted); - return res; + var newBytes = res.Buffer.Length - _currentRead.Length; + _currentRead = res.Buffer; + + if (res.IsCompleted || res.IsCanceled) + { + // TODO: figure out behavior + if (res.Buffer.Length >= FrameSize) + { + res = new(res.Buffer.Slice(FrameSize), res.IsCanceled, res.IsCompleted); + } + return res; + } + + ReadOnlySequence buffer = res.Buffer; + if (_remaining == 0) + { + // TODO: didn't get 16 bytes + var frame = buffer.Slice(0, FrameSize); + var len = ParseFrame(ref frame, _ackPipeReader); + _totalBytes += len; + // 0 len sent on reconnect and not part of acks + if (len != 0) + { + //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + res.Buffer.Length}"); + //_ackPipeWriter.lastAck += res.Buffer.Length; + } + + _remaining = len; + + // if the buffer doesn't have enough data we need to update how much we're slicing + if (len >= buffer.Length - FrameSize) + { + len = buffer.Length - FrameSize; + } + + buffer = buffer.Slice(FrameSize, len); + _currentRead = buffer; + _ackPipeWriter.lastAck += buffer.Length + FrameSize; + } + else + { + // Advance was called and didn't consume everything even though we gave it the entire Frame Length of data + // This means the caller is expecting more than a single frame of data + // We'll need to start buffering to parse multiple frames of data + if (_remaining <= _currentRead.Length && buffer.Length > _remaining) + { + // TODO + Console.WriteLine("multi frame"); + } + _ackPipeWriter.lastAck += Math.Min(_remaining, newBytes); + _currentRead = buffer; + buffer = buffer.Slice(0, Math.Min(_remaining, buffer.Length)); + } + + // TODO: validation everywhere! + //Debug.Assert(len < res.Buffer.Length); + + res = new(buffer, res.IsCanceled, res.IsCompleted); + + // TODO: probably should avoid returning when we have 0 bytes to return (unless canceled/completed) + //Debug.Assert(buffer.Length > 0); } - - // TODO: handle previous payload not fully received - // TODO: handle multiple frame prefixed messages - var frame = res.Buffer.Slice(0, 16); - var len = ParseFrame(frame, _ackPipeReader); - _totalBytes += len; - // 0 len sent on reconnect and not part of acks - if (len != 0) + catch (Exception ex) { - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + res.Buffer.Length}"); - //_ackPipeWriter.lastAck += res.Buffer.Length; + _inner.Complete(ex); + throw; } - // TODO: validation everywhere! - Debug.Assert(len < res.Buffer.Length); - - res = new(res.Buffer.Slice(16, len), res.IsCanceled, res.IsCompleted); return res; - static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) + static long ParseFrame(ref ReadOnlySequence frame, AckPipeReader ackPipeReader) { - Span buffer = stackalloc byte[16]; - frame.CopyTo(buffer); - -#if NETSTANDARD2_1_OR_GREATER - var len = BitConverter.ToInt64(buffer); - var ackId = BitConverter.ToInt64(buffer.Slice(8)); + Debug.Assert(frame.Length >= FrameSize); + + long len; + long ackId; +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + // Both the Span check and Stackalloc paths are faster than using SequenceReader + var frameSpan = frame.FirstSpan; + if (frameSpan.Length >= FrameSize) + { + Span decodedBytes = stackalloc byte[8]; + var status = Base64.DecodeFromUtf8(frameSpan.Slice(0, 12), decodedBytes, out var consumed, out var written, isFinalBlock: true); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(consumed == 12); + Debug.Assert(written == 8); + len = BitConverter.ToInt64(decodedBytes); + status = Base64.DecodeFromUtf8(frameSpan.Slice(12, 12), decodedBytes, out consumed, out written, isFinalBlock: true); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(consumed == 12); + Debug.Assert(written == 8); + ackId = BitConverter.ToInt64(decodedBytes); + } + else + { + Span buffer = stackalloc byte[FrameSize]; + frame.CopyTo(buffer); + var status = Base64.DecodeFromUtf8InPlace(buffer, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 8); + len = BitConverter.ToInt64(buffer); + status = Base64.DecodeFromUtf8InPlace(buffer.Slice(12), out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 8); + ackId = BitConverter.ToInt64(buffer.Slice(12)); + } #else - var len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); - var ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); +// TODO + Span buffer = stackalloc byte[FrameSize]; + frame.CopyTo(buffer); + len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); + ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); #endif + // Update ack id provided by other side, so the underlying pipe can release buffered memory ackPipeReader.Ack(ackId); return len; } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs index a9e8a93919fe..3adfd021c16e 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -3,8 +3,10 @@ using System; using System.Buffers; +using System.Buffers.Text; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.IO.Pipelines; using System.Linq; using System.Text; @@ -16,6 +18,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; public class AckPipeTests { + private const int FrameSize = 24; + [Fact] public async Task CanSendAndReceiveTransport() { @@ -90,12 +94,9 @@ public async Task ReadFromTransportRemovesFraming() { var duplexPipe = CreateClient(); - var buffer = new byte[20]; - WriteFrame(buffer, buffer.Length - 16, 0); - buffer[16] = 9; - buffer[17] = 9; - buffer[18] = 9; - buffer[19] = 9; + var buffer = new byte[28]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize, 0); // "write" from server await duplexPipe.Application.Output.WriteAsync(buffer); @@ -103,7 +104,7 @@ public async Task ReadFromTransportRemovesFraming() // read in client application layer var res = await duplexPipe.Transport.Input.ReadAsync(); Assert.Equal(4, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(16).ToArray(), res.Buffer.ToArray()); + Assert.Equal(buffer.AsSpan(FrameSize).ToArray(), res.Buffer.ToArray()); } [Fact] @@ -121,8 +122,567 @@ public async Task WriteFromApplicationAddsFraming() Assert.Equal(buffer.Length, framing.Length); Assert.Equal(0, framing.AckId); - Assert.Equal(buffer.Length + 16, res.Buffer.Length); - Assert.Equal(buffer, res.Buffer.Slice(16).ToArray()); + Assert.Equal(buffer.Length + FrameSize, res.Buffer.Length); + Assert.Equal(buffer, res.Buffer.Slice(FrameSize).ToArray()); + } + + [Fact] + public async Task MultipleWritesSingleFlushFromApplicationAddsFraming() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[20]; + Random.Shared.NextBytes(buffer); + + for (var i = 0; i < 3; i++) + { + var memory = duplexPipe.Transport.Output.GetMemory(); + buffer.CopyTo(memory); + duplexPipe.Transport.Output.Advance(buffer.Length); + } + await duplexPipe.Transport.Output.FlushAsync(); + + var res = await duplexPipe.Application.Input.ReadAsync(); + var framing = ReadFrame(res.Buffer.ToArray()); + + Assert.Equal(buffer.Length * 3, framing.Length); + Assert.Equal(0, framing.AckId); + Assert.Equal(buffer.Length * 3 + FrameSize, res.Buffer.Length); + Assert.Equal(buffer, res.Buffer.Slice(FrameSize, buffer.Length).ToArray()); + Assert.Equal(buffer, res.Buffer.Slice(FrameSize + buffer.Length, buffer.Length).ToArray()); + Assert.Equal(buffer, res.Buffer.Slice(FrameSize + buffer.Length * 2, buffer.Length).ToArray()); + } + + [Fact] + public async Task ReadFromTransportAcrossMultipleReads() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[28]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize + buffer.Length + buffer.Length, 0); + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + var res = await duplexPipe.Transport.Input.ReadAsync(); + + Assert.Equal(4, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(FrameSize).ToArray(), res.Buffer.ToArray()); + + // consume nothing + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); + + await duplexPipe.Application.Output.WriteAsync(buffer); + res = await duplexPipe.Transport.Input.ReadAsync(); + + Assert.Equal(32, res.Buffer.Length); + + // consume nothing + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); + + await duplexPipe.Application.Output.WriteAsync(buffer); + res = await duplexPipe.Transport.Input.ReadAsync(); + + Assert.Equal(60, res.Buffer.Length); + + // consume everything + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); + + // New write to make sure internal state is cleared from completed read + WriteFrame(buffer, buffer.Length - FrameSize, 0); + await duplexPipe.Application.Output.WriteAsync(buffer); + res = await duplexPipe.Transport.Input.ReadAsync(); + + Assert.Equal(4, res.Buffer.Length); + } + + [Fact] + public async Task ManyWritesSingleFlush_WritesSingleFrame() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[20]; + Random.Shared.NextBytes(buffer); + + var memory = duplexPipe.Transport.Output.GetMemory(); + Assert.True(memory.Length > buffer.Length); + buffer.CopyTo(memory); + duplexPipe.Transport.Output.Advance(buffer.Length); + + memory = duplexPipe.Transport.Output.GetMemory(); + Assert.True(memory.Length > buffer.Length); + buffer.CopyTo(memory); + duplexPipe.Transport.Output.Advance(buffer.Length); + + memory = duplexPipe.Transport.Output.GetMemory(); + Assert.True(memory.Length > buffer.Length); + buffer.CopyTo(memory); + duplexPipe.Transport.Output.Advance(buffer.Length); + + await duplexPipe.Transport.Output.FlushAsync(); + + var res = await duplexPipe.Application.Input.ReadAsync(); + var framing = ReadFrame(res.Buffer.ToArray()); + + Assert.Equal(buffer.Length * 3, framing.Length); + Assert.Equal(0, framing.AckId); + Assert.Equal(framing.Length + FrameSize, res.Buffer.Length); + + var buf = res.Buffer.Slice(FrameSize); + while (buf.Length > 0) + { + Assert.Equal(buffer, buf.Slice(0, buffer.Length).ToArray()); + buf = buf.Slice(buffer.Length); + } + } + + [Fact(Skip = "Something we want to support?")] + public async Task ReadFromTransportAcrossFrames() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[20]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize, 0); + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + var res = await duplexPipe.Transport.Input.ReadAsync(); + + Assert.Equal(4, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(FrameSize).ToArray(), res.Buffer.ToArray()); + + // consume nothing + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); + + res = await duplexPipe.Transport.Input.ReadAsync(); + // ?? + } + + [Fact] + public async Task AckFromTransportReadUpdatesApplicationBuffer() + { + var duplexPipe = CreateClient(); + // write something so we can ack it and see that the pipe has nothing in it + await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + var res = await duplexPipe.Application.Input.ReadAsync(); + // in real usage this will be advanced properly + // but we're claiming we read nothing so we can observe the ack behavior in the next read + duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start); + + var buffer = new byte[28]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - 24, ackId: FrameSize + 2); + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + // this reads the ack from the "server" and updates state + _ = await duplexPipe.Transport.Input.ReadAsync(); + + // this will be an empty read because the ack will be applied and everything will be marked as read + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(0, res.Buffer.Length); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + } + + [Fact] + public async Task AckFromTransportReadUpdatesApplicationBuffer_CanReadNewDataAfter() + { + // Basically the same test as AckFromTransportReadUpdatesApplicationBuffer but we write more data after the ack has fully flowed + // Just to smoke test that the pipe is still usable + + var duplexPipe = CreateClient(); + // write something so we can ack it + await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + var res = await duplexPipe.Application.Input.ReadAsync(); + // in real usage this will be advanced properly + // but we're claiming we read nothing so we can observe the ack behavior in the next read + duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start); + + var buffer = new byte[FrameSize + 4]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize, ackId: FrameSize + 2); + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + // this reads the ack from the "server" and updates state + res = await duplexPipe.Transport.Input.ReadAsync(); + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); + + // write again to update total sent + await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 42, 99 }); + + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(FrameSize + 2, res.Buffer.Length); + var (len, ack) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(2, len); + Assert.Equal(FrameSize + 4, ack); + Assert.Equal(new byte[] { 42, 99 }, res.Buffer.Slice(FrameSize).ToArray()); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + } + + [Fact] + public async Task ReceiveAckIdLargerThanTotalSentErrors() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[28]; + Random.Shared.NextBytes(buffer); + // ackId more than what has been sent + WriteFrame(buffer, buffer.Length - FrameSize, ackId: 30); + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + var exception = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); + Assert.Equal("Ack ID is greater than total amount of bytes that have been sent.", exception.Message); + + // Pipe is completed + exception = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); + Assert.Equal("Reading is not allowed after reader was completed.", exception.Message); + } + + // This is a fun edge case test, where if we have consumed everything in a BufferSegment and Acked everything too + // then consumed points to the end of the Segment, while Ack points to the beginning of the next Segment + // This test verifies that everything behaves correctly in that case + [Fact] + public async Task ConsumeAndAckAtEndOfSegment_CanServeNextSegment() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[4072]; + Random.Shared.NextBytes(buffer); + + // "write" from server + await duplexPipe.Transport.Output.WriteAsync(buffer); + + // read in client application layer + var res = await duplexPipe.Application.Input.ReadAsync(); + duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); + + Random.Shared.NextBytes(buffer); + await duplexPipe.Transport.Output.WriteAsync(buffer); + + var appBuffer = new byte[28]; + Random.Shared.NextBytes(appBuffer); + WriteFrame(appBuffer, appBuffer.Length - FrameSize, 4096); + await duplexPipe.Application.Output.WriteAsync(appBuffer); + + // Updates Ack in Application.Input + await duplexPipe.Transport.Input.ReadAsync(); + + res = await duplexPipe.Application.Input.ReadAsync(); + Assert.Equal(4096, res.Buffer.Length); + var (len, ack) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(4072, len); + Assert.Equal(0, ack); + Assert.Equal(buffer, res.Buffer.Slice(FrameSize).ToArray()); + Assert.True(res.Buffer.IsSingleSegment); + } + + [Fact] + public async Task ApplicationSendsAck() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[FrameSize + 4]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize, 0); + + // "write" from server + await duplexPipe.Application.Output.WriteAsync(buffer); + + // read in client application layer + var res = await duplexPipe.Transport.Input.ReadAsync(); + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); + + _ = await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + res = await duplexPipe.Application.Input.ReadAsync(); + var (length, ackId) = ReadFrame(res.Buffer.ToArray()); + + Assert.Equal(2, length); + Assert.Equal(FrameSize + 4, ackId); + } + + [Fact] + public async Task ApplicationSendsAckWithMultiSegment_ConsumingWhileReading() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[FrameSize + 5]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize, 0); + + // "write" from server, 26 of the 29 bytes, we want to force the reader to do two reads to get the full data + await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(0, FrameSize + 2).ToArray()); + + // read in client application layer + var res = await duplexPipe.Transport.Input.ReadAsync(); + Assert.Equal(2, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(FrameSize, 2).ToArray(), res.Buffer.ToArray()); + // Consume all seen so far + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); + + // write again, the last 3 of the 29 bytes + await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(FrameSize + 2, 3).ToArray()); + + res = await duplexPipe.Transport.Input.ReadAsync(); + Assert.Equal(3, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(FrameSize + 2, 3).ToArray(), res.Buffer.ToArray()); + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); + + _ = await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + res = await duplexPipe.Application.Input.ReadAsync(); + var (length, ackId) = ReadFrame(res.Buffer.ToArray()); + + Assert.Equal(2, length); + Assert.Equal(FrameSize + 5, ackId); + } + + [Fact] + public async Task ApplicationSendsAckWithMultiSegment_OnlyConsumeAtEnd() + { + var duplexPipe = CreateClient(); + + var buffer = new byte[29]; + Random.Shared.NextBytes(buffer); + WriteFrame(buffer, buffer.Length - FrameSize, 0); + + // "write" from server, 26 of the 29 bytes, we want to force the reader to do two reads to get the full data + await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(0, 26).ToArray()); + + // read in client application layer + var res = await duplexPipe.Transport.Input.ReadAsync(); + Assert.Equal(2, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(FrameSize, 2).ToArray(), res.Buffer.ToArray()); + // Don't consume any + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); + + // write again, the last 3 of the 29 bytes + await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(26, 3).ToArray()); + + res = await duplexPipe.Transport.Input.ReadAsync(); + Assert.Equal(5, res.Buffer.Length); + Assert.Equal(buffer.AsSpan(FrameSize, 5).ToArray(), res.Buffer.ToArray()); + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); + + _ = await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + res = await duplexPipe.Application.Input.ReadAsync(); + var (length, ackId) = ReadFrame(res.Buffer.ToArray()); + + Assert.Equal(2, length); + Assert.Equal(29, ackId); + } + + [Fact] + public async Task CompleteWithErrorFromTransportWriterFlowsToAppReader() + { + var duplexPipe = CreateClient(); + + duplexPipe.Transport.Output.Complete(new Exception("custom")); + + var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Application.Input.ReadAsync()); + Assert.Equal("custom", ex.Message); + } + + [Fact] + public async Task CompleteWithErrorFromTransportReaderFlowsToAppWriter() + { + var duplexPipe = CreateClient(); + + duplexPipe.Transport.Input.Complete(new Exception("custom")); + + var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Application.Output.FlushAsync()); + Assert.Equal("custom", ex.Message); + } + + [Fact] + public async Task CompleteWithErrorFromAppWriterFlowsToTransportReader() + { + var duplexPipe = CreateClient(); + + duplexPipe.Application.Output.Complete(new Exception("custom")); + + var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); + Assert.Equal("custom", ex.Message); + } + + [Fact] + public async Task CompleteWithErrorFromAppReaderFlowsToTransportWriter() + { + var duplexPipe = CreateClient(); + + duplexPipe.Application.Input.Complete(new Exception("custom")); + + var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Output.WriteAsync(new byte[1])); + Assert.Equal("custom", ex.Message); + } + + [Fact] + public async Task TriggerResendWithNothingWritten() + { + var duplexPipe = CreateClient(); + + var reader = (AckPipeReader)duplexPipe.Application.Input; + reader.Resend(); + + await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + var res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(FrameSize + 2, res.Buffer.Length); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + + var (length, ackId) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(2, length); + Assert.Equal(0, ackId); + } + + [Fact] + public async Task TriggerResendWithEverythingAcked() + { + var duplexPipe = CreateClient(); + + await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + // Read to pretend we've sent 18 bytes, so that an ack will be allowed + var res = await duplexPipe.Application.Input.ReadAsync(); + duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start, res.Buffer.Start); + + var buffer = new byte[FrameSize]; + WriteFrame(buffer, 0, FrameSize + 2); + await duplexPipe.Application.Output.WriteAsync(buffer); + + // Updates ack from App.Output in App.Input + _ = await duplexPipe.Transport.Input.ReadAsync(); + + var reader = (AckPipeReader)duplexPipe.Application.Input; + reader.Resend(); + + // Nothing returned since everything was acked before resend triggered + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(0, res.Buffer.Length); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); + + // smoke testing that we can still receive + await duplexPipe.Transport.Output.WriteAsync(new byte[2]); + + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(FrameSize + 2, res.Buffer.Length); + var (len, ackId) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(2, len); + Assert.Equal(FrameSize, ackId); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + } + + [Fact] + public async Task TriggerResendSendsEverythingNotAcked() + { + var duplexPipe = CreateClient(); + + // Write two frames of data + await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 1, 2 }); + await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 3, 4 }); + + var reader = (AckPipeReader)duplexPipe.Application.Input; + reader.Resend(); + + var res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(52, res.Buffer.Length); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + var (len, ackId) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(2, len); + Assert.Equal(0, ackId); + Assert.Equal(new byte[] { 1, 2 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); + (len, ackId) = ReadFrame(res.Buffer.ToArray().AsSpan(FrameSize + 2).ToArray()); + Assert.Equal(2, len); + Assert.Equal(0, ackId); + Assert.Equal(new byte[] { 3, 4 }, res.Buffer.ToArray().AsSpan(FrameSize * 2 + 2, 2).ToArray()); + + duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); + + // smoke testing that we can still receive + await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 4, 5 }); + + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(FrameSize + 2, res.Buffer.Length); + (len, ackId) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(2, len); + Assert.Equal(0, ackId); + Assert.Equal(new byte[] { 4, 5 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + } + + [Fact] + public async Task TriggerResendWhenPartialFrameAcked() + { + var duplexPipe = CreateClient(); + + await duplexPipe.Transport.Output.WriteAsync(new byte[] { 1, 2, 3, 4, 5, 6, 7 }); + // Read to pretend we've sent 31 bytes, so that an ack will be allowed + var res = await duplexPipe.Application.Input.ReadAsync(); + duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start, res.Buffer.Start); + + var buffer = new byte[FrameSize]; + // Only ack 26 of 31 bytes + WriteFrame(buffer, 0, FrameSize + 2); + await duplexPipe.Application.Output.WriteAsync(buffer); + + // Updates ack from App.Output in App.Input + _ = await duplexPipe.Transport.Input.ReadAsync(); + + var reader = (AckPipeReader)duplexPipe.Application.Input; + reader.Resend(); + + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(5, res.Buffer.Length); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); + Assert.Equal(new byte[] { 3, 4, 5, 6, 7 }, res.Buffer.ToArray()); + + duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); + + // smoke testing that we can still receive + await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 9, 7 }); + + res = await duplexPipe.Application.Input.ReadAsync(); + + Assert.Equal(FrameSize + 2, res.Buffer.Length); + var (len, ackId) = ReadFrame(res.Buffer.ToArray()); + Assert.Equal(2, len); + Assert.Equal(FrameSize, ackId); + Assert.Equal(new byte[] { 9, 7 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); + Assert.False(res.IsCanceled); + Assert.False(res.IsCompleted); } internal static DuplexPipePair CreateClient(PipeOptions inputOptions = default, PipeOptions outputOptions = default) @@ -160,26 +720,45 @@ internal static DuplexPipePair CreateServer(PipeOptions inputOptions = default, internal static void WriteFrame(byte[] header, long payloadLength, long ackId = 0) { - Assert.True(header.Length >= 16); + Assert.True(header.Length >= FrameSize); Assert.True(BitConverter.TryWriteBytes(header, payloadLength)); Assert.True(BitConverter.TryWriteBytes(header.AsSpan(8), ackId)); + var res = BitConverter.TryWriteBytes(header.AsSpan(), payloadLength); + Debug.Assert(res); + var status = Base64.EncodeToUtf8InPlace(header.AsSpan(), 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + res = BitConverter.TryWriteBytes(header.AsSpan(12), ackId); + Debug.Assert(res); + status = Base64.EncodeToUtf8InPlace(header.AsSpan(12), 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); } - internal static (long Length, long AckId) ReadFrame(byte[] buffer) + internal static (long Length, long AckId) ReadFrame(byte[] frameBytes) { + var frame = frameBytes.AsSpan(0, FrameSize); + Span buffer = stackalloc byte[FrameSize]; + frame.CopyTo(buffer); + var status = Base64.DecodeFromUtf8InPlace(buffer.Slice(0, 12), out var written); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(8, written); var len = BitConverter.ToInt64(buffer); - var ackId = BitConverter.ToInt64(buffer.AsSpan(8)); + status = Base64.DecodeFromUtf8InPlace(buffer.Slice(12, 12), out written); + Assert.Equal(OperationStatus.Done, status); + Assert.Equal(8, written); + var ackId = BitConverter.ToInt64(buffer.Slice(12)); return (len, ackId); } internal static (long PayloadLength, long AckId) ReadFrame(ref Span header) { - Assert.True(header.Length >= 16); + Assert.True(header.Length >= FrameSize); var len = BitConverter.ToInt64(header); - var ackId = BitConverter.ToInt64(header.Slice(8)); + var ackId = BitConverter.ToInt64(header.Slice(FrameSize / 2)); return (len, ackId); } From f5492da74eef4d7afe348f202309ef895e8798eb Mon Sep 17 00:00:00 2001 From: Brennan Date: Thu, 6 Apr 2023 09:18:12 -0700 Subject: [PATCH 04/25] E2E test :o --- .../csharp/Client.Core/src/HubConnection.cs | 1 + .../FunctionalTests/HubConnectionTests.cs | 69 ++++++++++ .../test/FunctionalTests/ProxyStartup.cs | 104 ++++++++++++++ .../src/Internal/LongPollingTransport.cs | 2 +- .../src/Internal/WebSocketsTransport.cs | 57 ++++++-- .../Transports/WebSocketsServerTransport.cs | 44 +++++- .../test/HttpConnectionDispatcherTests.cs | 4 +- .../common/Shared/AcknowledgePipeV2.cs | 129 +++++++++++++----- .../common/Shared/ParseAckPipeReader.cs | 82 +++++------ .../test/Internal/Protocol/AckPipeTests.cs | 2 +- .../server/Core/src/HubConnectionHandler.cs | 2 +- 11 files changed, 401 insertions(+), 95 deletions(-) create mode 100644 src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 797ecee716c8..862bb4a0cacc 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -1329,6 +1329,7 @@ async Task StartProcessingInvocationMessages(ChannelReader in { var result = await input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; + _logger.LogInformation("recv {len}", buffer.Length); try { diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index c3d34fa616de..276be3e76555 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -3,6 +3,7 @@ using System.Net; using System.Net.Http; +using System.Net.WebSockets; using System.Text.Json; using System.Threading.Channels; using Microsoft.AspNetCore.Connections; @@ -2537,6 +2538,74 @@ public async Task ServerSentEventsWorksWithHttp2OnlyEndpoint() } } + [Fact] + //[Theory] + //[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [Repeat(500)] + public async Task CanReconnectAndSendMessageWhileDisconnected(/*string protocolName, HttpTransportType transportType, string path*/) + { + var protocol = HubProtocols["json"]; + await using (var server = await StartServer()) + //await using (var proxyServer = await StartServer()) + { + //using var httpClient = new HttpClient(); + //await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, proxyServer.Url + $"/server?url={server.Url}")); + + var websocket = new ClientWebSocket(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + tcs.SetResult(); + + const string originalMessage = "SignalR"; + //var connection = CreateHubConnection(proxyServer.Url, "/default", HttpTransportType.WebSockets, protocol, LoggerFactory); + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => + { + o.WebSocketFactory = async (context, token) => + { + await tcs.Task; + await websocket.ConnectAsync(context.Uri, token); + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return websocket; + }; + }); + connectionBuilder.Services.AddSingleton(protocol); + var connection = connectionBuilder.Build(); + connection.ServerTimeout = TimeSpan.FromMinutes(2); + + try + { + await connection.StartAsync().DefaultTimeout(); + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + + Assert.Equal(originalMessage, result); + + var originalWebsocket = websocket; + websocket = new ClientWebSocket(); + + //var resultTask = connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + //await Task.Delay(1); + originalWebsocket.Dispose(); + //await Task.Delay(1000); + var resultTask = connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + tcs.SetResult(); + result = await resultTask; + + Assert.Equal(originalMessage, result); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + } + } + private class OneAtATimeSynchronizationContext : SynchronizationContext, IAsyncDisposable { private readonly Channel<(SendOrPostCallback, object)> _taskQueue = Channel.CreateUnbounded<(SendOrPostCallback, object)>(); diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs new file mode 100644 index 000000000000..1150baa476ab --- /dev/null +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.IdentityModel.Tokens.Jwt; +using System.IO; +using System.Net.Http; +using System.Net.WebSockets; +using System.Security.Claims; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Authentication.Negotiate; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.DataProtection; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.WebSockets; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using Newtonsoft.Json; + +namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests; + +public class ProxyStartup +{ + private string ServerUrl; + + public void ConfigureServices(IServiceCollection services) + { + // Since tests run in parallel, it's possible multiple servers will startup and read files being written by another test + // Use a unique directory per server to avoid this collision + services.AddDataProtection() + .PersistKeysToFileSystem(Directory.CreateDirectory(Path.GetRandomFileName())); + + services.AddWebSockets(o => o.KeepAliveInterval = TimeSpan.Zero); + + services.AddRouting(); + } + + public void Configure(IApplicationBuilder app) + { + app.UseRouting(); + app.UseWebSockets(); + + app.Use(next => + { + return async context => + { + if (context.Request.Path.Value.EndsWith("/server", StringComparison.Ordinal)) + { + ServerUrl = context.Request.Query["url"]; + } + else if (context.Request.Path.Value.EndsWith("/drop", StringComparison.Ordinal)) + { + // TODO: drop connection + // for testing seamless reconnect + } + else + { + // TODO: forward to server + if (context.WebSockets.IsWebSocketRequest) + { + var uriBuilder = new UriBuilder(ServerUrl); + uriBuilder.Path = context.Request.Path; + uriBuilder.Scheme = context.Request.IsHttps ? "wss" : "ws"; + uriBuilder.Query = context.Request.QueryString.Value; + using var ws = await context.WebSockets.AcceptWebSocketAsync(); + using var forwardingWebsocket = new ClientWebSocket(); + await forwardingWebsocket.ConnectAsync(uriBuilder.Uri, new CancellationTokenSource(TimeSpan.FromSeconds(30)).Token); + var recvTask = Forward(ws, forwardingWebsocket); + var sendTask = Forward(forwardingWebsocket, ws); + + await Task.WhenAny(recvTask, sendTask); + } + else + { + var uriBuilder = new UriBuilder(ServerUrl); + uriBuilder.Path = context.Request.Path; + uriBuilder.Query = context.Request.QueryString.Value; + using var httpClient = new HttpClient(); + var request = new HttpRequestMessage(new HttpMethod(context.Request.Method), uriBuilder.ToString()); + request.Content = new StreamContent(context.Request.Body); + var resp = await httpClient.SendAsync(request); + + context.Response.StatusCode = (int)resp.StatusCode; + await resp.Content.CopyToAsync(context.Response.Body); + } + } + await next(context); + }; + }); + } + + private static async Task Forward(WebSocket ws, WebSocket forwardWebSocket) + { + var buffer = new byte[4096]; + while (forwardWebSocket.CloseStatus is null) + { + var res = await ws.ReceiveAsync(buffer, cancellationToken: default); + await forwardWebSocket.SendAsync(buffer.AsMemory(..res.Count), res.MessageType, res.EndOfMessage, cancellationToken: default); + } + } +} diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs index 7168f74b866b..d3edf3356567 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs @@ -21,9 +21,9 @@ internal sealed partial class LongPollingTransport : ITransport private readonly HttpClient _httpClient; private readonly ILogger _logger; private readonly HttpConnectionOptions _httpConnectionOptions; + private readonly bool _useAck; private IDuplexPipe? _application; private IDuplexPipe? _transport; - private bool _useAck; // Volatile so that the poll loop sees the updated value set from a different thread private volatile Exception? _error; diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 528a0ed2d71d..469ae37519d1 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -310,13 +310,13 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // TODO: set pipe to start resend if (_application!.Input is AckPipeReader reader) { + _logger.LogInformation("start resend"); // write nothing so just the ackid gets sent to server // server will then send everything client may have missed as well as the last ackid so the client can resend - var buf = new byte[16]; - BitConverter.GetBytes(0).CopyTo(buf.AsMemory()); - BitConverter.GetBytes(((AckPipeWriter)(_transport.Output)).lastAck).CopyTo(buf.AsSpan().Slice(8)); - await _webSocket.SendAsync(new ArraySegment(buf, 0, 16), _webSocketMessageType, true, default).ConfigureAwait(false); - + var buf = new byte[AckPipeWriter.FrameSize]; + AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_transport.Output)).lastAck); + await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _webSocketMessageType, true, default).ConfigureAwait(false); + _logger.LogInformation("send resend {lastAck}", ((AckPipeWriter)(_transport.Output)).lastAck); // set after first send to server reader.Resend(); // once we've received something from the server (which will contain the ack id for the client) @@ -328,8 +328,30 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // Exceptions are handled above where the send and receive tasks are being run. var receiveResult = await _webSocket.ReceiveAsync(arraySegment, _stopCts.Token).ConfigureAwait(false); _application.Output.Advance(receiveResult.Count); + // server sends 0 length, but with latest ack, so there shouldn't be more than a frame of data sent + Debug.Assert(receiveResult.Count == AckPipeWriter.FrameSize); + LogBytes(memory.Slice(0, receiveResult.Count), _logger); + // Parsing ack id and updating reader here avoids issue where we send to server before receive loop runs, which is what normally updates ack + // This avoids resending data that was already acked + var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(memory), reader); + Debug.Assert(parsedLen == 0); + void LogBytes(Memory memory, ILogger logger) + { + var sb = new StringBuilder(); + sb.Append("received: "); + foreach (var b in memory.Span) + { + sb.Append($"0x{b:x} "); + } + logger.LogDebug(sb.ToString()); + } + _logger.LogInformation("recv resend"); var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false); + _logger.LogInformation("done resend"); + // TODO: figure out solution + // delay to allow receive loop to read, which updates the ack position so we don't resend data to the server + //await Task.Delay(2000); } } @@ -374,7 +396,7 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url) // 1. Waiting for application data // 2. Waiting for a websocket send to complete - if (_closed) + //if (_closed) { // Cancel the application so that ReadAsync yields _application.Input.CancelPendingRead(); @@ -423,6 +445,7 @@ private async Task StartReceiving(WebSocket socket) try { + _logger.LogInformation("recv started"); while (true) { #if NETSTANDARD2_1 || NETCOREAPP @@ -503,8 +526,15 @@ void LogBytes(Memory memory, ILogger logger) { if (!_aborted) { - _application.Output.Complete(ex); - _closed = true; + if (_closed) + { + _application.Output.Complete(ex); + } + else + { + _application.Output.CancelPendingFlush(); + } + //_closed = true; } } finally @@ -527,6 +557,8 @@ private async Task StartSending(WebSocket socket) try { + var ignoreFirstCanceled = true; + _logger.LogInformation("send started"); while (true) { var result = await _application.Input.ReadAsync().ConfigureAwait(false); @@ -548,11 +580,13 @@ void LogBytes(Memory memory, ILogger logger) try { - if (result.IsCanceled) + if (result.IsCanceled && !ignoreFirstCanceled) { break; } + ignoreFirstCanceled = false; + if (!buffer.IsEmpty) { try @@ -622,6 +656,11 @@ void LogBytes(Memory memory, ILogger logger) _application.Input.Complete(); } + if (error is not null) + { + _logger.LogInformation(error, "send loop"); + } + Log.SendStopped(_logger); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 91ac3c59125f..4ebc5add3f71 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Buffers; using System.Diagnostics; using System.IO.Pipelines; @@ -67,18 +68,34 @@ public async Task ProcessSocketAsync(WebSocket socket) { _aborted = false; // TODO: check if the pipe was used previously? - - reader.Resend(); - // wait for first read? - _ = await socket.ReceiveAsync(Memory.Empty, _connection.Cancellation?.Token ?? default); + // Currently checked in Resend + if (reader.Resend()) + { + // wait for first read? + var buf = new byte[AckPipeWriter.FrameSize]; + var res = await socket.ReceiveAsync(buf, _connection.Cancellation?.Token ?? default); + Debug.Assert(res.Count == AckPipeWriter.FrameSize); + var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(buf), reader); + Debug.Assert(parsedLen == 0); + await _application.Output.WriteAsync(buf); + + var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary + ? WebSocketMessageType.Binary + : WebSocketMessageType.Text); + buf = new byte[AckPipeWriter.FrameSize]; + AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_connection.Transport.Output)).lastAck); + _logger.LogInformation("sending resend ack {lastack}", ((AckPipeWriter)(_connection.Transport.Output)).lastAck); + await socket.SendAsync(buf, webSocketMessageType, endOfMessage: true, _connection.SendingToken); + } } // if (_application.Input.HasBeenUsedBefore) // read first to get the ack id for resending // set resend id on output pipe // start send loop which will resend and tell the client the last ack id it got from the read side // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + receiving = StartReceiving(socket); - sending = StartSending(socket); + sending = StartSending(socket, true); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending); @@ -150,8 +167,10 @@ private async Task StartReceiving(WebSocket socket) try { + _logger.LogInformation("start recv"); while (!token.IsCancellationRequested) { + // Do a 0 byte read so that idle connections don't allocate a buffer when waiting for a read var result = await socket.ReceiveAsync(Memory.Empty, token); @@ -225,12 +244,13 @@ void LogBytes(Memory memory, ILogger logger) } } - private async Task StartSending(WebSocket socket) + private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) { Exception? error = null; try { + _logger.LogInformation("start send"); while (true) { var result = await _application.Input.ReadAsync(); @@ -240,7 +260,7 @@ private async Task StartSending(WebSocket socket) try { - if (result.IsCanceled) + if (result.IsCanceled && !ignoreFirstCancel) { break; } @@ -297,6 +317,16 @@ void LogBytes(Memory memory, ILogger logger) { break; } + else if (ignoreFirstCancel) + { + //var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary + // ? WebSocketMessageType.Binary + // : WebSocketMessageType.Text); + //var buf = new byte[AckPipeWriter.FrameSize]; + //AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_connection.Transport.Output)).lastAck); + //await socket.SendAsync(buffer, webSocketMessageType, _connection.SendingToken); + } + ignoreFirstCancel = false; } finally { diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index d7f044089634..40b25d09dfc0 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1366,7 +1366,7 @@ public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived() } [Theory] - //[InlineData(HttpTransportType.WebSockets)] + [InlineData(HttpTransportType.WebSockets)] [InlineData(HttpTransportType.ServerSentEvents)] public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTransportType transportType) { @@ -1392,7 +1392,7 @@ public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTrans var options = new HttpConnectionDispatcherOptions(); var request1 = dispatcher.ExecuteAsync(context1, options, app); - await dispatcher.ExecuteAsync(context2, options, app); + await dispatcher.ExecuteAsync(context2, options, app).DefaultTimeout(); Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AcknowledgePipeV2.cs index d5149ff08a1f..0dc90fa0ae5b 100644 --- a/src/SignalR/common/Shared/AcknowledgePipeV2.cs +++ b/src/SignalR/common/Shared/AcknowledgePipeV2.cs @@ -50,23 +50,27 @@ public void Ack(long byteID) if (_totalWritten < byteID) { - Throw(); - static void Throw() + Throw(byteID, _totalWritten); + static void Throw(long id, long total) { - throw new InvalidOperationException("Ack ID is greater than total amount of bytes that have been sent."); + throw new InvalidOperationException($"Ack ID '{id}' is greater than total amount of '{total}' bytes that have been sent."); } } } } - public void Resend() + public bool Resend() { Debug.Assert(_resend == false); if (_totalWritten == 0) { - return; + return false; } + // Unblocks ReadAsync and gives a buffer with the examined but not consumed bytes + // This avoids the issue where we wait for someone to write to the pipe before completing the reconnect handshake + CancelPendingRead(); _resend = true; + return true; } public override void AdvanceTo(SequencePosition consumed) @@ -77,13 +81,23 @@ public override void AdvanceTo(SequencePosition consumed) public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) { _consumed = consumed; + //if (_ackPosition.Equals(default)) + //{ + // Debug.Assert(false); + // _inner.AdvanceTo(consumed, examined); + //} + //else + //{ + _inner.AdvanceTo(_ackPosition, examined); + //} + if (_consumed.Equals(_ackPosition)) { // Reset to default, we check this in ReadAsync to know if we should provide the current read buffer to the user // Or slice to the consumed position _consumed = default; + _ackPosition = default; } - _inner.AdvanceTo(_ackPosition, examined); } public override void CancelPendingRead() @@ -100,10 +114,12 @@ public override async ValueTask ReadAsync(CancellationToken cancella { var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); var buffer = res.Buffer; + long hadAck = 0; lock (_lock) { if (_ackDiff != 0) { + hadAck = _ackDiff; // This detects the odd scenario where _consumed points to the end of a Segment and buffer.Slice(_ackDiff) points to the beginning of the next Segment // While they technically point to different positions, they point to the same concept of "beginning of the next buffer" var ackSlice = buffer.Slice(_ackDiff); @@ -111,25 +127,58 @@ public override async ValueTask ReadAsync(CancellationToken cancella { // Fix consumed to point to the beginning of the next Segment _consumed = ackSlice.Start; - // wtf does this do though - //_consumed = buffer.Slice(buffer.Length - buffer.Slice(_consumed).Length).Start; + } + else if (!_consumed.Equals(default)) + { + if (buffer.Slice(_consumed).Length == ackSlice.Length) + { + _consumed = default; + } + else if (buffer.Slice(_consumed).Length > ackSlice.Length) + { + Debug.Assert(false); + } + else if (buffer.Slice(_consumed).Length < ackSlice.Length) + { + // this is normal, ack id is less than total written + + //_totalWritten += ackSlice.Length - buffer.Slice(_consumed).Length; + } } buffer = ackSlice; _ackId += _ackDiff; _ackDiff = 0; _ackPosition = buffer.Start; + //if (buffer.Length == 0) + //{ + // _ackPosition = default; + // _consumed = default; + //} } } - + bool wasResend = _resend; // Slice consumed, unless resending, then slice to ackPosition if (_resend) { _resend = false; - buffer = buffer.Slice(_ackPosition); - // update total written? + if (buffer.Length != 0 && !_ackPosition.Equals(default)) + { + buffer = buffer.Slice(_ackPosition); + } + // update total written if there is more written to the pipe during a reconnect + // TODO: add tests for both these paths + if (!_consumed.Equals(default)) + { + Debug.Assert(buffer.Length - buffer.Slice(_consumed).Length >= 0); + _totalWritten += buffer.Length - buffer.Slice(_consumed).Length; + } + else + { + _totalWritten += buffer.Length; + } } - else + else if (buffer.Length > 0) { _ackPosition = buffer.Start; // TODO: buffer.Length is 0 sometimes, figure out why and verify behavior @@ -140,6 +189,11 @@ public override async ValueTask ReadAsync(CancellationToken cancella _totalWritten += (uint)buffer.Length; } res = new(buffer, res.IsCanceled, res.IsCompleted); + //if (buffer.Length == 0) + //{ + // // everything has been acked + // _ackPosition = default; + //} return res; } @@ -152,7 +206,7 @@ public override bool TryRead(out ReadResult result) // Wrapper around a PipeWriter that adds framing to writes internal sealed class AckPipeWriter : PipeWriter { - private const int FrameSize = 24; + public const int FrameSize = 24; private readonly PipeWriter _inner; internal long lastAck; @@ -194,27 +248,7 @@ public override ValueTask FlushAsync(CancellationToken cancellation { Debug.Assert(_frameHeader.Length >= FrameSize); -#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - var res = BitConverter.TryWriteBytes(_frameHeader.Span, _buffered); - Debug.Assert(res); - var status = Base64.EncodeToUtf8InPlace(_frameHeader.Span, 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - res = BitConverter.TryWriteBytes(_frameHeader.Slice(12).Span, lastAck); - Debug.Assert(res); - status = Base64.EncodeToUtf8InPlace(_frameHeader.Slice(12).Span, 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); -#else - BitConverter.GetBytes(_buffered).CopyTo(_frameHeader); - var status = Base64.EncodeToUtf8InPlace(_frameHeader.Span, 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - BitConverter.GetBytes(lastAck).CopyTo(_frameHeader.Slice(12).Span); - status = Base64.EncodeToUtf8InPlace(_frameHeader.Slice(12).Span, 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); -#endif + WriteFrame(_frameHeader.Span, _buffered, lastAck); _frameHeader = Memory.Empty; _buffered = 0; @@ -239,4 +273,31 @@ public override Span GetSpan(int sizeHint = 0) { return GetMemory(sizeHint).Span; } + + public static void WriteFrame(Span header, long length, long ack) + { + Debug.Assert(header.Length >= FrameSize); + +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + var res = BitConverter.TryWriteBytes(header, length); + Debug.Assert(res); + var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + res = BitConverter.TryWriteBytes(header.Slice(12), ack); + Debug.Assert(res); + status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); +#else + BitConverter.GetBytes(length).CopyTo(header); + var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + BitConverter.GetBytes(ack).CopyTo(header.Slice(12)); + status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); +#endif + } } diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index 737a283322c2..9a248f9abfd0 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -91,7 +91,7 @@ public override async ValueTask ReadAsync(CancellationToken cancella { // TODO: didn't get 16 bytes var frame = buffer.Slice(0, FrameSize); - var len = ParseFrame(ref frame, _ackPipeReader); + var len = ParseFrame(in frame, _ackPipeReader); _totalBytes += len; // 0 len sent on reconnect and not part of acks if (len != 0) @@ -110,7 +110,9 @@ public override async ValueTask ReadAsync(CancellationToken cancella buffer = buffer.Slice(FrameSize, len); _currentRead = buffer; - _ackPipeWriter.lastAck += buffer.Length + FrameSize; + // 0 length means it was part of the reconnect handshake and not sent over the pipe, ignore it for acking purposes + // TODO: check if 0 byte writes are possible in ConnectionHandlers and possibly handle them differently + _ackPipeWriter.lastAck += buffer.Length == 0 ? 0 : buffer.Length + FrameSize; } else { @@ -142,43 +144,44 @@ public override async ValueTask ReadAsync(CancellationToken cancella } return res; + } - static long ParseFrame(ref ReadOnlySequence frame, AckPipeReader ackPipeReader) - { - Debug.Assert(frame.Length >= FrameSize); + public static long ParseFrame(in ReadOnlySequence frame, AckPipeReader ackPipeReader) + { + Debug.Assert(frame.Length >= FrameSize); - long len; - long ackId; + long len; + long ackId; #if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - // Both the Span check and Stackalloc paths are faster than using SequenceReader - var frameSpan = frame.FirstSpan; - if (frameSpan.Length >= FrameSize) - { - Span decodedBytes = stackalloc byte[8]; - var status = Base64.DecodeFromUtf8(frameSpan.Slice(0, 12), decodedBytes, out var consumed, out var written, isFinalBlock: true); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(consumed == 12); - Debug.Assert(written == 8); - len = BitConverter.ToInt64(decodedBytes); - status = Base64.DecodeFromUtf8(frameSpan.Slice(12, 12), decodedBytes, out consumed, out written, isFinalBlock: true); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(consumed == 12); - Debug.Assert(written == 8); - ackId = BitConverter.ToInt64(decodedBytes); - } - else - { - Span buffer = stackalloc byte[FrameSize]; - frame.CopyTo(buffer); - var status = Base64.DecodeFromUtf8InPlace(buffer, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 8); - len = BitConverter.ToInt64(buffer); - status = Base64.DecodeFromUtf8InPlace(buffer.Slice(12), out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 8); - ackId = BitConverter.ToInt64(buffer.Slice(12)); - } + // Both the Span check and Stackalloc paths are faster than using SequenceReader + var frameSpan = frame.FirstSpan; + if (frameSpan.Length >= FrameSize) + { + Span decodedBytes = stackalloc byte[8]; + var status = Base64.DecodeFromUtf8(frameSpan.Slice(0, 12), decodedBytes, out var consumed, out var written, isFinalBlock: true); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(consumed == 12); + Debug.Assert(written == 8); + len = BitConverter.ToInt64(decodedBytes); + status = Base64.DecodeFromUtf8(frameSpan.Slice(12, 12), decodedBytes, out consumed, out written, isFinalBlock: true); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(consumed == 12); + Debug.Assert(written == 8); + ackId = BitConverter.ToInt64(decodedBytes); + } + else + { + Span buffer = stackalloc byte[FrameSize]; + frame.CopyTo(buffer); + var status = Base64.DecodeFromUtf8InPlace(buffer, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 8); + len = BitConverter.ToInt64(buffer); + status = Base64.DecodeFromUtf8InPlace(buffer.Slice(12), out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 8); + ackId = BitConverter.ToInt64(buffer.Slice(12)); + } #else // TODO Span buffer = stackalloc byte[FrameSize]; @@ -186,10 +189,9 @@ static long ParseFrame(ref ReadOnlySequence frame, AckPipeReader ackPipeRe len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); #endif - // Update ack id provided by other side, so the underlying pipe can release buffered memory - ackPipeReader.Ack(ackId); - return len; - } + // Update ack id provided by other side, so the underlying pipe can release buffered memory + ackPipeReader.Ack(ackId); + return len; } public override bool TryRead(out ReadResult result) diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs index 3adfd021c16e..612688ca5013 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -351,7 +351,7 @@ public async Task ReceiveAckIdLargerThanTotalSentErrors() // read in client application layer var exception = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); - Assert.Equal("Ack ID is greater than total amount of bytes that have been sent.", exception.Message); + Assert.Equal("Ack ID '30' is greater than total amount of '0' bytes that have been sent.", exception.Message); // Pipe is completed exception = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 6b0b4c1cfaee..df573d42c2ab 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -260,7 +260,7 @@ void LogBytes(Memory memory, ILogger logger) { sb.Append($"0x{b:x} "); } - logger.LogDebug(sb.ToString()); + logger.LogDebug($"read: {sb}"); } try From 31d0906012751dd31b91313c57fa23f1b4768981 Mon Sep 17 00:00:00 2001 From: Brennan Date: Wed, 12 Apr 2023 13:56:52 -0700 Subject: [PATCH 05/25] some cleanup --- .../FunctionalTests/HubConnectionTests.cs | 141 ++++++++++++++++-- .../Client/test/FunctionalTests/Hubs.cs | 5 + .../UnitTests/HttpConnectionFactoryTests.cs | 1 + .../src/HttpConnection.cs | 6 +- .../src/HttpConnectionFactory.cs | 3 +- .../src/HttpConnectionOptions.cs | 5 + .../src/Internal/WebSocketsTransport.cs | 11 +- .../src/PublicAPI.Unshipped.txt | 2 + .../src/Internal/HttpConnectionContext.cs | 13 +- .../src/Internal/HttpConnectionDispatcher.cs | 20 ++- .../src/Internal/HttpConnectionManager.cs | 4 +- .../Transports/WebSocketsServerTransport.cs | 8 +- .../Http.Connections/test/WebSocketsTests.cs | 3 +- .../common/Shared/AcknowledgePipeV2.cs | 29 ++-- .../common/Shared/ParseAckPipeReader.cs | 9 +- .../test/Internal/Protocol/AckPipeTests.cs | 11 +- .../server/Core/src/HubConnectionHandler.cs | 2 +- .../server/SignalR/test/EndToEndTests.cs | 2 +- 18 files changed, 212 insertions(+), 63 deletions(-) diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 276be3e76555..625c0ebb1d0f 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -2546,17 +2546,12 @@ public async Task CanReconnectAndSendMessageWhileDisconnected(/*string protocolN { var protocol = HubProtocols["json"]; await using (var server = await StartServer()) - //await using (var proxyServer = await StartServer()) { - //using var httpClient = new HttpClient(); - //await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, proxyServer.Url + $"/server?url={server.Url}")); - var websocket = new ClientWebSocket(); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); tcs.SetResult(); const string originalMessage = "SignalR"; - //var connection = CreateHubConnection(proxyServer.Url, "/default", HttpTransportType.WebSockets, protocol, LoggerFactory); var connectionBuilder = new HubConnectionBuilder() .WithLoggerFactory(LoggerFactory) .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => @@ -2568,14 +2563,15 @@ public async Task CanReconnectAndSendMessageWhileDisconnected(/*string protocolN tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); return websocket; }; + o.UseAcks = true; }); connectionBuilder.Services.AddSingleton(protocol); var connection = connectionBuilder.Build(); - connection.ServerTimeout = TimeSpan.FromMinutes(2); try { await connection.StartAsync().DefaultTimeout(); + var originalConnectionId = connection.ConnectionId; var result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); @@ -2583,16 +2579,141 @@ public async Task CanReconnectAndSendMessageWhileDisconnected(/*string protocolN var originalWebsocket = websocket; websocket = new ClientWebSocket(); - - //var resultTask = connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); - //await Task.Delay(1); originalWebsocket.Dispose(); - //await Task.Delay(1000); + var resultTask = connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); tcs.SetResult(); result = await resultTask; Assert.Equal(originalMessage, result); + Assert.Equal(originalConnectionId, connection.ConnectionId); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + } + } + + [Fact] + //[Theory] + //[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [Repeat(500)] + public async Task CanReconnectAndSendMessageOnceConnected(/*string protocolName, HttpTransportType transportType, string path*/) + { + var protocol = HubProtocols["json"]; + await using (var server = await StartServer()) + { + var websocket = new ClientWebSocket(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + const string originalMessage = "SignalR"; + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => + { + o.WebSocketFactory = async (context, token) => + { + await websocket.ConnectAsync(context.Uri, token); + tcs.SetResult(); + return websocket; + }; + o.UseAcks = true; + }) + .WithAutomaticReconnect(); + connectionBuilder.Services.AddSingleton(protocol); + var connection = connectionBuilder.Build(); + + var reconnectCalled = false; + connection.Reconnecting += ex => + { + reconnectCalled = true; + return Task.CompletedTask; + }; + + try + { + await connection.StartAsync().DefaultTimeout(); + await tcs.Task; + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var originalConnectionId = connection.ConnectionId; + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + + Assert.Equal(originalMessage, result); + + var originalWebsocket = websocket; + websocket = new ClientWebSocket(); + + originalWebsocket.Dispose(); + + await tcs.Task.DefaultTimeout(); + result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + + Assert.Equal(originalMessage, result); + Assert.Equal(originalConnectionId, connection.ConnectionId); + Assert.False(reconnectCalled); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + } + } + + [Fact] + //[Theory] + //[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [Repeat(500)] + public async Task ServerAbortsConnectionNoReconnectAttempted(/*string protocolName, HttpTransportType transportType, string path*/) + { + var protocol = HubProtocols["json"]; + await using (var server = await StartServer()) + { + var connectCount = 0; + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => + { + o.WebSocketFactory = async (context, token) => + { + connectCount++; + var ws = new ClientWebSocket(); + await ws.ConnectAsync(context.Uri, token); + return ws; + }; + o.UseAcks = true; + }); + connectionBuilder.Services.AddSingleton(protocol); + var connection = connectionBuilder.Build(); + + var closedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection.Closed += ex => + { + closedTcs.SetResult(ex); + return Task.CompletedTask; + }; + + try + { + await connection.StartAsync().DefaultTimeout(); + + await connection.SendAsync(nameof(TestHub.Abort)).DefaultTimeout(); + + Assert.Null(await closedTcs.Task.DefaultTimeout()); + Assert.Equal(HubConnectionState.Disconnected, connection.State); + Assert.Equal(1, connectCount); } catch (Exception ex) { diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs index 0b270d5e4a4b..fdb2f56c175a 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs @@ -101,6 +101,11 @@ public string GetHttpProtocol() return Context.GetHttpContext()?.Request?.Protocol ?? "unknown"; } + public void Abort() + { + Context.Abort(); + } + public async Task CallWithUnserializableObject() { await Clients.All.SendAsync("Foo", Unserializable.Create()); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs index bc9952577264..f019ec2a749b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs @@ -100,6 +100,7 @@ public void ShallowCopyHttpConnectionOptionsCopiesAllPublicProperties() { $"{nameof(HttpConnectionOptions.WebSocketFactory)}", webSocketFactory }, { $"{nameof(HttpConnectionOptions.ApplicationMaxBufferSize)}", 1L * 1024 * 1024 }, { $"{nameof(HttpConnectionOptions.TransportMaxBufferSize)}", 1L * 1024 * 1024 }, + { $"{nameof(HttpConnectionOptions.UseAcks)}", true }, }; var options = new HttpConnectionOptions(); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 4434d8c02d6e..9a40615c1c4e 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -453,7 +453,11 @@ private async Task NegotiateAsync(Uri url, HttpClient httpC { uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); } - uri = Utils.AppendQueryString(uri, "useAck=true"); + + if (_httpConnectionOptions.UseAcks) + { + uri = Utils.AppendQueryString(uri, "useAck=true"); + } using (var request = new HttpRequestMessage(HttpMethod.Post, uri)) { diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs index 03b52fa385a4..b5f2b035e071 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs @@ -88,7 +88,8 @@ internal static HttpConnectionOptions ShallowCopyHttpConnectionOptions(HttpConne CloseTimeout = options.CloseTimeout, DefaultTransferFormat = options.DefaultTransferFormat, ApplicationMaxBufferSize = options.ApplicationMaxBufferSize, - TransportMaxBufferSize = options.TransportMaxBufferSize + TransportMaxBufferSize = options.TransportMaxBufferSize, + UseAcks = options.UseAcks, }; if (!OperatingSystem.IsBrowser()) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs index 443dc2307786..d63a60c136e4 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs @@ -275,6 +275,11 @@ public Action? WebSocketConfiguration } } + /// + /// TODO + /// + public bool UseAcks { get; set; } + private static void ThrowIfUnsupportedPlatform() { if (OperatingSystem.IsBrowser()) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 469ae37519d1..ba89974a0df5 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -314,9 +314,9 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // write nothing so just the ackid gets sent to server // server will then send everything client may have missed as well as the last ackid so the client can resend var buf = new byte[AckPipeWriter.FrameSize]; - AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_transport.Output)).lastAck); + AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_transport.Output)).LastAck); await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _webSocketMessageType, true, default).ConfigureAwait(false); - _logger.LogInformation("send resend {lastAck}", ((AckPipeWriter)(_transport.Output)).lastAck); + _logger.LogInformation("send resend {lastAck}", ((AckPipeWriter)(_transport.Output)).LastAck); // set after first send to server reader.Resend(); // once we've received something from the server (which will contain the ack id for the client) @@ -330,7 +330,7 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio _application.Output.Advance(receiveResult.Count); // server sends 0 length, but with latest ack, so there shouldn't be more than a frame of data sent Debug.Assert(receiveResult.Count == AckPipeWriter.FrameSize); - LogBytes(memory.Slice(0, receiveResult.Count), _logger); + //LogBytes(memory.Slice(0, receiveResult.Count), _logger); // Parsing ack id and updating reader here avoids issue where we send to server before receive loop runs, which is what normally updates ack // This avoids resending data that was already acked var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(memory), reader); @@ -494,7 +494,7 @@ private async Task StartReceiving(WebSocket socket) Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - LogBytes(memory.Slice(0, receiveResult.Count), _logger); + //LogBytes(memory.Slice(0, receiveResult.Count), _logger); void LogBytes(Memory memory, ILogger logger) { var sb = new StringBuilder(); @@ -557,6 +557,7 @@ private async Task StartSending(WebSocket socket) try { + // TODO: only for acks var ignoreFirstCanceled = true; _logger.LogInformation("send started"); while (true) @@ -564,7 +565,7 @@ private async Task StartSending(WebSocket socket) var result = await _application.Input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; - LogBytes(buffer.ToArray(), _logger); + //LogBytes(buffer.ToArray(), _logger); void LogBytes(Memory memory, ILogger logger) { var sb = new StringBuilder(); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt b/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..80a34974298d 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Http.Connections.Client.HttpConnectionOptions.UseAcks.get -> bool +Microsoft.AspNetCore.Http.Connections.Client.HttpConnectionOptions.UseAcks.set -> void diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index a93156e81233..9927481610c6 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -46,6 +46,7 @@ internal sealed partial class HttpConnectionContext : ConnectionContext, private CancellationTokenSource? _sendCts; private bool _activeSend; private long _startedSendTime; + private readonly bool _useAcks; private readonly object _sendingLock = new object(); internal CancellationToken SendingToken { get; private set; } @@ -57,7 +58,8 @@ internal sealed partial class HttpConnectionContext : ConnectionContext, /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. /// The caller is expected to set the and pipes manually. /// - public HttpConnectionContext(string connectionId, string connectionToken, ILogger logger, MetricsContext metricsContext, IDuplexPipe transport, IDuplexPipe application, HttpConnectionDispatcherOptions options) + public HttpConnectionContext(string connectionId, string connectionToken, ILogger logger, MetricsContext metricsContext, + IDuplexPipe transport, IDuplexPipe application, HttpConnectionDispatcherOptions options, bool useAcks) { Transport = transport; _applicationStream = new PipeWriterStream(application.Output); @@ -95,8 +97,11 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge _connectionCloseRequested = new CancellationTokenSource(); ConnectionClosedRequested = _connectionCloseRequested.Token; AuthenticationExpiration = DateTimeOffset.MaxValue; + _useAcks = useAcks; } + public bool UseAcks => _useAcks; + public CancellationTokenSource? Cancellation { get; set; } public HttpTransportType TransportType { get; set; } @@ -528,6 +533,12 @@ internal async Task CancelPreviousPoll(HttpContext context) // Cancel the previous request cts?.Cancel(); + // TODO: remove transport check once other transports support acks + if (UseAcks && TransportType == HttpTransportType.WebSockets) + { + Application.Input.CancelPendingRead(); + } + try { // Wait for the previous request to drain diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index c129671b4dee..a38072c94054 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -190,7 +190,6 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti var transport = HttpTransportType.LongPolling; if (context.WebSockets.IsWebSocketRequest) { - transport = HttpTransportType.WebSockets; } else @@ -221,10 +220,13 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti return; } - if (!await connection.CancelPreviousPoll(context)) + if (connection.TransportType != HttpTransportType.WebSockets || connection.UseAcks) { - // Connection closed. It's already set the response status code. - return; + if (!await connection.CancelPreviousPoll(context)) + { + // Connection closed. It's already set the response status code. + return; + } } // Create a new Tcs every poll to keep track of the poll finishing, so we can properly wait on previous polls @@ -236,7 +238,10 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti break; case HttpTransportType.WebSockets: var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); - connection.TryActivatePersistentConnection(connectionDelegate, ws, currentRequestTcs.Task, context, _logger); + if (!connection.TryActivatePersistentConnection(connectionDelegate, ws, currentRequestTcs.Task, context, _logger)) + { + return; + } break; case HttpTransportType.LongPolling: if (!connection.TryActivateLongPollingConnection( @@ -299,10 +304,11 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } else { - Console.WriteLine("waiting transporttask"); + // If false then the transport was ungracefully closed, this can mean a temporary network disconnection + // We'll mark the connection as inactive and allow the connection to reconnect if that's the case. + // TODO: If acks aren't enabled we can close the connection immediately if (await connection.TransportTask!) { - Console.WriteLine("disposing"); await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); } else diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 11da9a2ecc7d..0d34cd55b362 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -93,13 +93,13 @@ internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions _metrics.ConnectionStart(metricsContext); var pair = DuplexPipe.CreateConnectionPair(options.TransportPipeOptions, options.AppPipeOptions); - var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger, metricsContext, pair.Application, pair.Transport, options); + var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger, metricsContext, pair.Application, pair.Transport, options, useAck); _connections.TryAdd(connectionToken, (connection, startTimestamp)); return connection; - static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + static DuplexPipePair CreateAckConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) { var input = new Pipe(inputOptions); var output = new Pipe(outputOptions); diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 4ebc5add3f71..f182201bee68 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -83,8 +83,8 @@ public async Task ProcessSocketAsync(WebSocket socket) ? WebSocketMessageType.Binary : WebSocketMessageType.Text); buf = new byte[AckPipeWriter.FrameSize]; - AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_connection.Transport.Output)).lastAck); - _logger.LogInformation("sending resend ack {lastack}", ((AckPipeWriter)(_connection.Transport.Output)).lastAck); + AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_connection.Transport.Output)).LastAck); + _logger.LogInformation("sending resend ack {lastack}", ((AckPipeWriter)(_connection.Transport.Output)).LastAck); await socket.SendAsync(buf, webSocketMessageType, endOfMessage: true, _connection.SendingToken); } } @@ -203,7 +203,7 @@ void LogBytes(Memory memory, ILogger logger) } Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - LogBytes(memory.Slice(0, receiveResult.Count), _logger); + //LogBytes(memory.Slice(0, receiveResult.Count), _logger); _application.Output.Advance(receiveResult.Count); @@ -275,7 +275,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) ? WebSocketMessageType.Binary : WebSocketMessageType.Text); - LogBytes(buffer.ToArray(), _logger); + //LogBytes(buffer.ToArray(), _logger); void LogBytes(Memory memory, ILogger logger) { diff --git a/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs b/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs index a3283f1b8a51..48163419726b 100644 --- a/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs +++ b/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs @@ -110,7 +110,8 @@ public async Task WebSocketTransportSetsMessageTypeBasedOnTransferFormatFeature( private HttpConnectionContext CreateHttpConnectionContext(DuplexPipe.DuplexPipePair pair, string loggerName = null) { - return new HttpConnectionContext("foo", connectionToken: null, LoggerFactory.CreateLogger(loggerName ?? nameof(HttpConnectionContext)), metricsContext: default, pair.Transport, pair.Application, new()); + return new HttpConnectionContext("foo", connectionToken: null, LoggerFactory.CreateLogger(loggerName ?? nameof(HttpConnectionContext)), + metricsContext: default, pair.Transport, pair.Application, new(), useAcks: false); } [Fact] diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AcknowledgePipeV2.cs index 0dc90fa0ae5b..64a87953fbcf 100644 --- a/src/SignalR/common/Shared/AcknowledgePipeV2.cs +++ b/src/SignalR/common/Shared/AcknowledgePipeV2.cs @@ -44,9 +44,7 @@ public void Ack(long byteID) { return; } - //Debug.Assert(byteID >= _ackId); _ackDiff = byteID - _ackId; - //Console.WriteLine($"AckId: {byteID}"); if (_totalWritten < byteID) { @@ -61,13 +59,14 @@ static void Throw(long id, long total) public bool Resend() { + // TODO: Do we need to check this? Debug.Assert(_resend == false); if (_totalWritten == 0) { return false; } // Unblocks ReadAsync and gives a buffer with the examined but not consumed bytes - // This avoids the issue where we wait for someone to write to the pipe before completing the reconnect handshake + // This avoids the issue where we have to wait for someone to write to the pipe before completing the reconnect handshake CancelPendingRead(); _resend = true; return true; @@ -136,13 +135,15 @@ public override async ValueTask ReadAsync(CancellationToken cancella } else if (buffer.Slice(_consumed).Length > ackSlice.Length) { + // ack is greater than consumed, should not be possible + + // TODO: verify that if ack is less than total but more than consumed this isn't hit + // e.g. 13 bytes in underlying pipe, only consumed 11 during Read+Advance. Will an ack id of 12 be allowed? Debug.Assert(false); } else if (buffer.Slice(_consumed).Length < ackSlice.Length) { // this is normal, ack id is less than total written - - //_totalWritten += ackSlice.Length - buffer.Slice(_consumed).Length; } } @@ -150,11 +151,6 @@ public override async ValueTask ReadAsync(CancellationToken cancella _ackId += _ackDiff; _ackDiff = 0; _ackPosition = buffer.Start; - //if (buffer.Length == 0) - //{ - // _ackPosition = default; - // _consumed = default; - //} } } bool wasResend = _resend; @@ -181,19 +177,14 @@ public override async ValueTask ReadAsync(CancellationToken cancella else if (buffer.Length > 0) { _ackPosition = buffer.Start; - // TODO: buffer.Length is 0 sometimes, figure out why and verify behavior - if (buffer.Length > 0 && !_consumed.Equals(default)) + if (!_consumed.Equals(default)) { buffer = buffer.Slice(_consumed); } _totalWritten += (uint)buffer.Length; } + res = new(buffer, res.IsCanceled, res.IsCompleted); - //if (buffer.Length == 0) - //{ - // // everything has been acked - // _ackPosition = default; - //} return res; } @@ -208,7 +199,7 @@ internal sealed class AckPipeWriter : PipeWriter { public const int FrameSize = 24; private readonly PipeWriter _inner; - internal long lastAck; + internal long LastAck; Memory _frameHeader; bool _shouldAdvanceFrameHeader; @@ -248,7 +239,7 @@ public override ValueTask FlushAsync(CancellationToken cancellation { Debug.Assert(_frameHeader.Length >= FrameSize); - WriteFrame(_frameHeader.Span, _buffered, lastAck); + WriteFrame(_frameHeader.Span, _buffered, LastAck); _frameHeader = Memory.Empty; _buffered = 0; diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index 9a248f9abfd0..5b940562f68d 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -89,7 +89,7 @@ public override async ValueTask ReadAsync(CancellationToken cancella ReadOnlySequence buffer = res.Buffer; if (_remaining == 0) { - // TODO: didn't get 16 bytes + // TODO: didn't get 24 bytes var frame = buffer.Slice(0, FrameSize); var len = ParseFrame(in frame, _ackPipeReader); _totalBytes += len; @@ -112,7 +112,7 @@ public override async ValueTask ReadAsync(CancellationToken cancella _currentRead = buffer; // 0 length means it was part of the reconnect handshake and not sent over the pipe, ignore it for acking purposes // TODO: check if 0 byte writes are possible in ConnectionHandlers and possibly handle them differently - _ackPipeWriter.lastAck += buffer.Length == 0 ? 0 : buffer.Length + FrameSize; + _ackPipeWriter.LastAck += buffer.Length == 0 ? 0 : buffer.Length + FrameSize; } else { @@ -121,10 +121,9 @@ public override async ValueTask ReadAsync(CancellationToken cancella // We'll need to start buffering to parse multiple frames of data if (_remaining <= _currentRead.Length && buffer.Length > _remaining) { - // TODO - Console.WriteLine("multi frame"); + // TODO: multi-frame support } - _ackPipeWriter.lastAck += Math.Min(_remaining, newBytes); + _ackPipeWriter.LastAck += Math.Min(_remaining, newBytes); _currentRead = buffer; buffer = buffer.Slice(0, Math.Min(_remaining, buffer.Length)); } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs index 612688ca5013..472f77d8cd0e 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -582,7 +582,7 @@ public async Task TriggerResendWithEverythingAcked() res = await duplexPipe.Application.Input.ReadAsync(); Assert.Equal(0, res.Buffer.Length); - Assert.False(res.IsCanceled); + Assert.True(res.IsCanceled); Assert.False(res.IsCompleted); duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); @@ -594,7 +594,7 @@ public async Task TriggerResendWithEverythingAcked() Assert.Equal(FrameSize + 2, res.Buffer.Length); var (len, ackId) = ReadFrame(res.Buffer.ToArray()); Assert.Equal(2, len); - Assert.Equal(FrameSize, ackId); + Assert.Equal(0, ackId); Assert.False(res.IsCanceled); Assert.False(res.IsCompleted); } @@ -657,7 +657,8 @@ public async Task TriggerResendWhenPartialFrameAcked() await duplexPipe.Application.Output.WriteAsync(buffer); // Updates ack from App.Output in App.Input - _ = await duplexPipe.Transport.Input.ReadAsync(); + res = await duplexPipe.Transport.Input.ReadAsync(); + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); var reader = (AckPipeReader)duplexPipe.Application.Input; reader.Resend(); @@ -665,7 +666,7 @@ public async Task TriggerResendWhenPartialFrameAcked() res = await duplexPipe.Application.Input.ReadAsync(); Assert.Equal(5, res.Buffer.Length); - Assert.False(res.IsCanceled); + Assert.True(res.IsCanceled); Assert.False(res.IsCompleted); Assert.Equal(new byte[] { 3, 4, 5, 6, 7 }, res.Buffer.ToArray()); @@ -679,7 +680,7 @@ public async Task TriggerResendWhenPartialFrameAcked() Assert.Equal(FrameSize + 2, res.Buffer.Length); var (len, ackId) = ReadFrame(res.Buffer.ToArray()); Assert.Equal(2, len); - Assert.Equal(FrameSize, ackId); + Assert.Equal(0, ackId); Assert.Equal(new byte[] { 9, 7 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); Assert.False(res.IsCanceled); Assert.False(res.IsCompleted); diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index df573d42c2ab..418154e421f2 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -251,7 +251,7 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection) { var result = await input.ReadAsync(); var buffer = result.Buffer; - LogBytes(buffer.ToArray(), _logger); + //LogBytes(buffer.ToArray(), _logger); void LogBytes(Memory memory, ILogger logger) { diff --git a/src/SignalR/server/SignalR/test/EndToEndTests.cs b/src/SignalR/server/SignalR/test/EndToEndTests.cs index bdde2c2673f1..e280f20af9ef 100644 --- a/src/SignalR/server/SignalR/test/EndToEndTests.cs +++ b/src/SignalR/server/SignalR/test/EndToEndTests.cs @@ -341,7 +341,7 @@ async Task ReceiveMessage() { logger.LogInformation("Receiving message"); // Big timeout here because it can take a while to receive all the bytes - var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length).DefaultTimeout(TimeSpan.FromMinutes(2)); + var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length).DefaultTimeout(); Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); logger.LogInformation("Completed receive"); } From 3cb065c3251f3ebc2cdf02f6cd37680f1deafa8d Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 14 Apr 2023 10:48:01 -0700 Subject: [PATCH 06/25] cleanup/comments --- .../csharp/Client.Core/src/HubConnection.cs | 1 - .../FunctionalTests/HubConnectionTests.cs | 2 +- .../src/Internal/ServerSentEventsTransport.cs | 1 - .../src/Internal/WebSocketsTransport.cs | 126 ++++++------------ .../src/NegotiationResponse.cs | 3 + .../src/Internal/HttpConnectionDispatcher.cs | 16 +-- .../Transports/WebSocketsServerTransport.cs | 60 +++------ .../common/Shared/AcknowledgePipeV2.cs | 21 +-- .../common/Shared/ParseAckPipeReader.cs | 73 ++++------ .../test/Internal/Protocol/AckPipeTests.cs | 56 ++++++-- ...oft.AspNetCore.SignalR.Common.Tests.csproj | 1 - .../server/Core/src/HubConnectionHandler.cs | 13 -- 12 files changed, 154 insertions(+), 219 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 862bb4a0cacc..797ecee716c8 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -1329,7 +1329,6 @@ async Task StartProcessingInvocationMessages(ChannelReader in { var result = await input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; - _logger.LogInformation("recv {len}", buffer.Length); try { diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 625c0ebb1d0f..c113d7b9a2ed 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -2698,7 +2698,7 @@ public async Task ServerAbortsConnectionNoReconnectAttempted(/*string protocolNa connectionBuilder.Services.AddSingleton(protocol); var connection = connectionBuilder.Build(); - var closedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.Closed += ex => { closedTcs.SetResult(ex); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs index 48c45bf8647f..cd49ee5d203c 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs @@ -76,7 +76,6 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio throw; } - // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) DuplexPipePair pair; if (_useAck) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index ba89974a0df5..1dcfcc450c2d 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -289,13 +289,15 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio _stopCts = new CancellationTokenSource(); + var ignoreFirstCanceled = false; + if (_transport is null) { // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) DuplexPipePair pair; if (_useAck) { - pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + pair = CreateAckConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); } else { @@ -307,59 +309,44 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio } else { - // TODO: set pipe to start resend if (_application!.Input is AckPipeReader reader) { - _logger.LogInformation("start resend"); - // write nothing so just the ackid gets sent to server - // server will then send everything client may have missed as well as the last ackid so the client can resend - var buf = new byte[AckPipeWriter.FrameSize]; - AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_transport.Output)).LastAck); - await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _webSocketMessageType, true, default).ConfigureAwait(false); - _logger.LogInformation("send resend {lastAck}", ((AckPipeWriter)(_transport.Output)).LastAck); - // set after first send to server - reader.Resend(); - // once we've received something from the server (which will contain the ack id for the client) - // we can start the normal read/write loops, clients first send will resend everything the server missed - var memory = _application.Output.GetMemory(); - var isArray = MemoryMarshal.TryGetArray(memory, out var arraySegment); - Debug.Assert(isArray); - - // Exceptions are handled above where the send and receive tasks are being run. - var receiveResult = await _webSocket.ReceiveAsync(arraySegment, _stopCts.Token).ConfigureAwait(false); - _application.Output.Advance(receiveResult.Count); - // server sends 0 length, but with latest ack, so there shouldn't be more than a frame of data sent - Debug.Assert(receiveResult.Count == AckPipeWriter.FrameSize); - //LogBytes(memory.Slice(0, receiveResult.Count), _logger); - // Parsing ack id and updating reader here avoids issue where we send to server before receive loop runs, which is what normally updates ack - // This avoids resending data that was already acked - var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(memory), reader); - Debug.Assert(parsedLen == 0); - void LogBytes(Memory memory, ILogger logger) + if (reader.Resend()) { - var sb = new StringBuilder(); - sb.Append("received: "); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug(sb.ToString()); + // Start reconnect ack handshake + // 1. Send ack ID to server for last message we recieved from server before we disconnected + // 2. Read from server to get the last ack ID it received before we disconnecting + // 3. Resume normal send/receive loops + + ignoreFirstCanceled = true; + var buf = new byte[AckPipeWriter.FrameSize]; + AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)_transport.Output).LastAck); + await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _webSocketMessageType, true, _stopCts.Token).ConfigureAwait(false); + + Array.Clear(buf, 0, buf.Length); + var receiveResult = await _webSocket.ReceiveAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _stopCts.Token).ConfigureAwait(false); + // server sends 0 length, but with latest ack, so there shouldn't be more than a frame of data sent + Debug.Assert(receiveResult.Count == AckPipeWriter.FrameSize); + // Parsing ack id and updating reader here avoids issue where we send to server before receive loop runs, which is what normally updates ack + // This avoids resending data that was already acked + var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(buf), reader); + + // TODO: why do we need to unblock the receive loop to not delay/block shutdown sometimes? + // Looks like calling stop/dispose on the client doesn't avoid the reconnect cycle in the transport, we'll need to fix that + var flushResult = await _application.Output.FlushAsync(default).ConfigureAwait(false); } - _logger.LogInformation("recv resend"); - - var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false); - _logger.LogInformation("done resend"); - // TODO: figure out solution - // delay to allow receive loop to read, which updates the ack position so we don't resend data to the server - //await Task.Delay(2000); + } + else + { + Debug.Assert(false); } } // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 - Running = ProcessSocketAsync(_webSocket, url); + Running = ProcessSocketAsync(_webSocket, url, ignoreFirstCanceled); - static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + static DuplexPipePair CreateAckConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) { var input = new Pipe(inputOptions); var output = new Pipe(outputOptions); @@ -375,7 +362,7 @@ static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions } } - private async Task ProcessSocketAsync(WebSocket socket, Uri url) + private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirstCanceled) { Debug.Assert(_application != null); @@ -383,7 +370,7 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url) { // Begin sending and receiving. var receiving = StartReceiving(socket); - var sending = StartSending(socket); + var sending = StartSending(socket, ignoreFirstCanceled); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false); @@ -396,11 +383,8 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url) // 1. Waiting for application data // 2. Waiting for a websocket send to complete - //if (_closed) - { - // Cancel the application so that ReadAsync yields - _application.Input.CancelPendingRead(); - } + // Cancel the application so that ReadAsync yields + _application.Input.CancelPendingRead(); var resultTask = await Task.WhenAny(sending, Task.Delay(_closeTimeout, _stopCts.Token)).ConfigureAwait(false); @@ -431,8 +415,6 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url) } } - Console.WriteLine("closed socket"); - if (_useAck && !_closed) { await StartAsync(url, _webSocketMessageType == WebSocketMessageType.Binary ? TransferFormat.Binary : TransferFormat.Text, default).ConfigureAwait(false); @@ -445,7 +427,6 @@ private async Task StartReceiving(WebSocket socket) try { - _logger.LogInformation("recv started"); while (true) { #if NETSTANDARD2_1 || NETCOREAPP @@ -494,18 +475,6 @@ private async Task StartReceiving(WebSocket socket) Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - //LogBytes(memory.Slice(0, receiveResult.Count), _logger); - void LogBytes(Memory memory, ILogger logger) - { - var sb = new StringBuilder(); - sb.Append("received: "); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug(sb.ToString()); - } - _application.Output.Advance(receiveResult.Count); var flushResult = await _application.Output.FlushAsync().ConfigureAwait(false); @@ -549,7 +518,7 @@ void LogBytes(Memory memory, ILogger logger) } } - private async Task StartSending(WebSocket socket) + private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) { Debug.Assert(_application != null); @@ -557,26 +526,11 @@ private async Task StartSending(WebSocket socket) try { - // TODO: only for acks - var ignoreFirstCanceled = true; - _logger.LogInformation("send started"); while (true) { var result = await _application.Input.ReadAsync().ConfigureAwait(false); var buffer = result.Buffer; - //LogBytes(buffer.ToArray(), _logger); - void LogBytes(Memory memory, ILogger logger) - { - var sb = new StringBuilder(); - sb.Append("sending: "); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug(sb.ToString()); - } - // Get a frame from the application try @@ -654,13 +608,9 @@ void LogBytes(Memory memory, ILogger logger) if (_closed) { - _application.Input.Complete(); - } - - if (error is not null) - { - _logger.LogInformation(error, "send loop"); + _application.Input.Complete(error); } + // TODO: log error in else? Log.SendStopped(_logger); } diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index 17ba7671064d..4ec99bba1055 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -45,5 +45,8 @@ public class NegotiationResponse /// public string? Error { get; set; } + /// + /// + /// public bool UseAcking { get; set; } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index a38072c94054..16060e692351 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -155,7 +155,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti if (connection.TryActivatePersistentConnection(connectionDelegate, sse, Task.CompletedTask, context, _logger)) { - await DoPersistentConnection(connectionDelegate, sse, context, connection); + await DoPersistentConnection(connection); } } //else if (context.WebSockets.IsWebSocketRequest) @@ -327,16 +327,12 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } } - private async Task DoPersistentConnection(ConnectionDelegate connectionDelegate, - IHttpTransport transport, - HttpContext context, - HttpConnectionContext connection) + private async Task DoPersistentConnection(HttpConnectionContext connection, HttpContext context) { - //if (connection.TryActivatePersistentConnection(connectionDelegate, transport, context, _logger)) - { - context.Features.Get()?.DisableTimeout(); - // Wait for any of them to end - await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!); + context.Features.Get()?.DisableTimeout(); + + // Wait for any of them to end + await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!); await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index f182201bee68..e133e08ba06b 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -1,12 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; using System.Diagnostics; using System.IO.Pipelines; using System.Net.WebSockets; -using System.Text; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using PipelinesOverNetwork; @@ -21,6 +19,7 @@ internal sealed partial class WebSocketsServerTransport : IHttpTransport private readonly HttpConnectionContext _connection; private volatile bool _aborted; + // Used to determine if the close was graceful or a network issue private bool _closed; public WebSocketsServerTransport(WebSocketOptions options, IDuplexPipe application, HttpConnectionContext connection, ILoggerFactory loggerFactory) @@ -62,8 +61,6 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationTok public async Task ProcessSocketAsync(WebSocket socket) { - Task receiving; - Task sending; if (_application.Input is AckPipeReader reader) { _aborted = false; @@ -71,31 +68,38 @@ public async Task ProcessSocketAsync(WebSocket socket) // Currently checked in Resend if (reader.Resend()) { + // Start reconnect ack handshake + // 1. Read from client to get the last ack ID it received before disconnecting + // 2. Send ack ID to client for last message we received from client before it disconnected + // 3. Resume normal send/receive loops + // wait for first read? var buf = new byte[AckPipeWriter.FrameSize]; var res = await socket.ReceiveAsync(buf, _connection.Cancellation?.Token ?? default); Debug.Assert(res.Count == AckPipeWriter.FrameSize); + // Needed so that the readers ack position gets updated and we don't re-send messages to client + // Normally this would be done by the HubConnectionHandler loop, but that requires a new message to be read + // so we instead make sure it's updated immediately here var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(buf), reader); Debug.Assert(parsedLen == 0); - await _application.Output.WriteAsync(buf); + // we don't need to write to the pipe if we parse the frame? + //await _application.Output.WriteAsync(buf); var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary ? WebSocketMessageType.Binary : WebSocketMessageType.Text); - buf = new byte[AckPipeWriter.FrameSize]; - AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_connection.Transport.Output)).LastAck); - _logger.LogInformation("sending resend ack {lastack}", ((AckPipeWriter)(_connection.Transport.Output)).LastAck); + Array.Clear(buf); + Debug.Assert(_connection.Transport.Output is AckPipeWriter); + AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)_connection.Transport.Output).LastAck); + _connection.StartSendCancellation(); + // send without going through the Pipe, we don't treat this as an ackable message await socket.SendAsync(buf, webSocketMessageType, endOfMessage: true, _connection.SendingToken); + _connection.StopSendCancellation(); } } - // if (_application.Input.HasBeenUsedBefore) - // read first to get the ack id for resending - // set resend id on output pipe - // start send loop which will resend and tell the client the last ack id it got from the read side - // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. - receiving = StartReceiving(socket); - sending = StartSending(socket, true); + var receiving = StartReceiving(socket); + var sending = StartSending(socket, true); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending); @@ -167,7 +171,6 @@ private async Task StartReceiving(WebSocket socket) try { - _logger.LogInformation("start recv"); while (!token.IsCancellationRequested) { @@ -191,17 +194,6 @@ private async Task StartReceiving(WebSocket socket) return; } - void LogBytes(Memory memory, ILogger logger) - { - var sb = new StringBuilder(); - sb.Append("received: "); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug(sb.ToString()); - } - Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); //LogBytes(memory.Slice(0, receiveResult.Count), _logger); @@ -250,7 +242,6 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) try { - _logger.LogInformation("start send"); while (true) { var result = await _application.Input.ReadAsync(); @@ -275,19 +266,6 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) ? WebSocketMessageType.Binary : WebSocketMessageType.Text); - //LogBytes(buffer.ToArray(), _logger); - - void LogBytes(Memory memory, ILogger logger) - { - var sb = new StringBuilder(); - sb.Append("sending: "); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug(sb.ToString()); - } - if (WebSocketCanSend(socket)) { _connection.StartSendCancellation(); diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AcknowledgePipeV2.cs index 64a87953fbcf..4ff94e831e05 100644 --- a/src/SignalR/common/Shared/AcknowledgePipeV2.cs +++ b/src/SignalR/common/Shared/AcknowledgePipeV2.cs @@ -66,7 +66,8 @@ public bool Resend() return false; } // Unblocks ReadAsync and gives a buffer with the examined but not consumed bytes - // This avoids the issue where we have to wait for someone to write to the pipe before completing the reconnect handshake + // This avoids the issue where we have to wait for someone to write to the pipe before + // the receive loop will see what might have been written during disconnect CancelPendingRead(); _resend = true; return true; @@ -113,12 +114,11 @@ public override async ValueTask ReadAsync(CancellationToken cancella { var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); var buffer = res.Buffer; - long hadAck = 0; + lock (_lock) { if (_ackDiff != 0) { - hadAck = _ackDiff; // This detects the odd scenario where _consumed points to the end of a Segment and buffer.Slice(_ackDiff) points to the beginning of the next Segment // While they technically point to different positions, they point to the same concept of "beginning of the next buffer" var ackSlice = buffer.Slice(_ackDiff); @@ -129,11 +129,12 @@ public override async ValueTask ReadAsync(CancellationToken cancella } else if (!_consumed.Equals(default)) { - if (buffer.Slice(_consumed).Length == ackSlice.Length) + var consumedLength = buffer.Slice(_consumed).Length; + if (consumedLength == ackSlice.Length) { _consumed = default; } - else if (buffer.Slice(_consumed).Length > ackSlice.Length) + else if (consumedLength > ackSlice.Length) { // ack is greater than consumed, should not be possible @@ -141,7 +142,7 @@ public override async ValueTask ReadAsync(CancellationToken cancella // e.g. 13 bytes in underlying pipe, only consumed 11 during Read+Advance. Will an ack id of 12 be allowed? Debug.Assert(false); } - else if (buffer.Slice(_consumed).Length < ackSlice.Length) + else if (consumedLength < ackSlice.Length) { // this is normal, ack id is less than total written } @@ -153,7 +154,7 @@ public override async ValueTask ReadAsync(CancellationToken cancella _ackPosition = buffer.Start; } } - bool wasResend = _resend; + // Slice consumed, unless resending, then slice to ackPosition if (_resend) { @@ -231,8 +232,10 @@ public override void Complete(Exception? exception = null) _inner.Complete(exception); } - // X - 8 byte size of payload as uint - // Y - 8 byte number of acked bytes + // TODO: We could reduce this to 16 bytes for binary transports and avoid the base64 encode/decode + // TODO: We could also reduce this to 1 + 12 (or 8) bytes occasionally if we add a flag for no new ack ID and avoid sending an ack + // X - 12 byte - size of payload as long and base64 encoded + // Y - 12 byte - number of acked bytes as long and base64 encoded // Z - payload // [ XXXX YYYY ZZZZ ] public override ValueTask FlushAsync(CancellationToken cancellationToken = default) diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index 5b940562f68d..9e723777f605 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -17,7 +17,7 @@ namespace PipelinesOverNetwork; // Parse framing and slice the read so the application doesn't see the framing // Notify outbound pipe of framing details for when sending back // Notify application pipe of ack id provided by other side of the network -internal class ParseAckPipeReader : PipeReader +internal sealed class ParseAckPipeReader : PipeReader { private const int FrameSize = 24; private readonly PipeReader _inner; @@ -91,19 +91,13 @@ public override async ValueTask ReadAsync(CancellationToken cancella { // TODO: didn't get 24 bytes var frame = buffer.Slice(0, FrameSize); - var len = ParseFrame(in frame, _ackPipeReader); + var len = ParseFrame(frame, _ackPipeReader); _totalBytes += len; - // 0 len sent on reconnect and not part of acks - if (len != 0) - { - //Console.WriteLine($"lastack: {_ackPipeWriter.lastAck} to {_ackPipeWriter.lastAck + res.Buffer.Length}"); - //_ackPipeWriter.lastAck += res.Buffer.Length; - } _remaining = len; // if the buffer doesn't have enough data we need to update how much we're slicing - if (len >= buffer.Length - FrameSize) + if (len > buffer.Length - FrameSize) { len = buffer.Length - FrameSize; } @@ -145,49 +139,40 @@ public override async ValueTask ReadAsync(CancellationToken cancella return res; } - public static long ParseFrame(in ReadOnlySequence frame, AckPipeReader ackPipeReader) + public static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) { Debug.Assert(frame.Length >= FrameSize); + frame = frame.Slice(0, FrameSize); long len; long ackId; + + // TODO: check perf of single Span check vs Stackalloc + Span buffer = stackalloc byte[FrameSize]; + frame.CopyTo(buffer); + var status = Base64.DecodeFromUtf8InPlace(buffer.Slice(0, FrameSize / 2), out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 8); + #if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - // Both the Span check and Stackalloc paths are faster than using SequenceReader - var frameSpan = frame.FirstSpan; - if (frameSpan.Length >= FrameSize) - { - Span decodedBytes = stackalloc byte[8]; - var status = Base64.DecodeFromUtf8(frameSpan.Slice(0, 12), decodedBytes, out var consumed, out var written, isFinalBlock: true); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(consumed == 12); - Debug.Assert(written == 8); - len = BitConverter.ToInt64(decodedBytes); - status = Base64.DecodeFromUtf8(frameSpan.Slice(12, 12), decodedBytes, out consumed, out written, isFinalBlock: true); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(consumed == 12); - Debug.Assert(written == 8); - ackId = BitConverter.ToInt64(decodedBytes); - } - else - { - Span buffer = stackalloc byte[FrameSize]; - frame.CopyTo(buffer); - var status = Base64.DecodeFromUtf8InPlace(buffer, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 8); - len = BitConverter.ToInt64(buffer); - status = Base64.DecodeFromUtf8InPlace(buffer.Slice(12), out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 8); - ackId = BitConverter.ToInt64(buffer.Slice(12)); - } + len = BitConverter.ToInt64(buffer); #else -// TODO - Span buffer = stackalloc byte[FrameSize]; - frame.CopyTo(buffer); - len = BitConverter.ToInt64(buffer.Slice(0, 8).ToArray(), 0); - ackId = BitConverter.ToInt64(buffer.Slice(8).ToArray(), 0); + var longBuf = new byte[8]; + buffer.Slice(0, 8).CopyTo(longBuf); + len = BitConverter.ToInt64(longBuf, 0); #endif + + status = Base64.DecodeFromUtf8InPlace(buffer.Slice(FrameSize / 2), out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 8); + +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + ackId = BitConverter.ToInt64(buffer.Slice(FrameSize / 2)); +#else + buffer.Slice(12, 8).CopyTo(longBuf); + ackId = BitConverter.ToInt64(longBuf, 0); +#endif + // Update ack id provided by other side, so the underlying pipe can release buffered memory ackPipeReader.Ack(ackId); return len; diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs index 472f77d8cd0e..d0adfdc64488 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -1,18 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; using System.Buffers.Text; -using System.Collections; -using System.Collections.Generic; using System.Diagnostics; using System.IO.Pipelines; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using PipelinesOverNetwork; -using static PipelinesOverNetwork.AckDuplexPipe; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; @@ -23,7 +16,7 @@ public class AckPipeTests [Fact] public async Task CanSendAndReceiveTransport() { - var duplexPipe = AckDuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + var duplexPipe = CreateConnectionPair(new PipeOptions(), new PipeOptions()); var values = new byte[] { 1, 2, 3, 4, 5 }; var flushRes = await duplexPipe.Transport.Output.WriteAsync(values); @@ -42,7 +35,7 @@ public async Task CanSendAndReceiveTransport() [Fact] public async Task CanSendAndReceiveLargeAmount() { - var duplexPipe = AckDuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + var duplexPipe = CreateConnectionPair(new PipeOptions(), new PipeOptions()); var values = new byte[20000]; Random.Shared.NextBytes(values); @@ -62,7 +55,7 @@ public async Task CanSendAndReceiveLargeAmount() [Fact] public async Task CanSendAndReceiveLargeAmount_ManyWritesSingleFlush() { - var duplexPipe = AckDuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + var duplexPipe = CreateConnectionPair(new PipeOptions(), new PipeOptions()); var values = new byte[20000]; Random.Shared.NextBytes(values); @@ -763,4 +756,47 @@ internal static (long PayloadLength, long AckId) ReadFrame(ref Span header return (len, ackId); } + + internal static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + // wire up both sides for testing + var ackWriterApp = new AckPipeWriter(output.Writer); + var ackReaderApp = new AckPipeReader(output.Reader); + var ackWriterClient = new AckPipeWriter(input.Writer); + var ackReaderClient = new AckPipeReader(input.Reader); + var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReaderApp); + var applicationReader = new ParseAckPipeReader(ackReaderApp, ackWriterClient, ackReaderClient); + var transportToApplication = new DuplexPipe(applicationReader, ackWriterClient); + var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + internal sealed class DuplexPipe : IDuplexPipe + { + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + } + + public readonly struct DuplexPipePair + { + public IDuplexPipe Transport { get; } + public IDuplexPipe Application { get; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } } diff --git a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index a6631b548a30..9a08da67fabb 100644 --- a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -7,7 +7,6 @@ - diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 418154e421f2..ab3d0f5bbd7b 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -1,9 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; using System.Linq; -using System.Text; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; @@ -251,17 +249,6 @@ private async Task DispatchMessagesAsync(HubConnectionContext connection) { var result = await input.ReadAsync(); var buffer = result.Buffer; - //LogBytes(buffer.ToArray(), _logger); - - void LogBytes(Memory memory, ILogger logger) - { - var sb = new StringBuilder(); - foreach (var b in memory.Span) - { - sb.Append($"0x{b:x} "); - } - logger.LogDebug($"read: {sb}"); - } try { From 47e34566a649e4d68a796d198d944aa9860d7b26 Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 14 Apr 2023 13:14:48 -0700 Subject: [PATCH 07/25] move files and namespace --- .../src/Internal/LongPollingTransport.cs | 1 - .../src/Internal/ServerSentEventsTransport.cs | 1 - .../src/Internal/WebSocketsTransport.cs | 1 - ....AspNetCore.Http.Connections.Client.csproj | 3 +- .../src/Internal/HttpConnectionManager.cs | 1 - .../Transports/WebSocketsServerTransport.cs | 1 - ...crosoft.AspNetCore.Http.Connections.csproj | 3 +- ...{AcknowledgePipeV2.cs => AckPipeReader.cs} | 111 +---------------- src/SignalR/common/Shared/AckPipeWriter.cs | 115 ++++++++++++++++++ .../Shared/AcknowledgePipe/DuplexPipe.cs | 93 -------------- .../common/Shared/ParseAckPipeReader.cs | 2 +- .../test/Internal/Protocol/AckPipeTests.cs | 2 +- ...oft.AspNetCore.SignalR.Common.Tests.csproj | 3 +- 13 files changed, 127 insertions(+), 210 deletions(-) rename src/SignalR/common/Shared/{AcknowledgePipeV2.cs => AckPipeReader.cs} (65%) create mode 100644 src/SignalR/common/Shared/AckPipeWriter.cs delete mode 100644 src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs index d3edf3356567..04cb1a7c7982 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs @@ -11,7 +11,6 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using PipelinesOverNetwork; using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs index cd49ee5d203c..e8583be8d516 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs @@ -13,7 +13,6 @@ using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using PipelinesOverNetwork; using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 1dcfcc450c2d..3872e13c51cc 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -19,7 +19,6 @@ using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using PipelinesOverNetwork; using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj b/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj index 54cd016c2db4..d6204265fcb0 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj @@ -11,7 +11,8 @@ - + + diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 0d34cd55b362..d0bd03d47c74 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -11,7 +11,6 @@ using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using PipelinesOverNetwork; using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Internal; diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index e133e08ba06b..c6d20d67b0fb 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -7,7 +7,6 @@ using System.Net.WebSockets; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -using PipelinesOverNetwork; namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports; diff --git a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj index c731acef4c54..10c11fbe5ac6 100644 --- a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj @@ -17,7 +17,8 @@ - + + diff --git a/src/SignalR/common/Shared/AcknowledgePipeV2.cs b/src/SignalR/common/Shared/AckPipeReader.cs similarity index 65% rename from src/SignalR/common/Shared/AcknowledgePipeV2.cs rename to src/SignalR/common/Shared/AckPipeReader.cs index 4ff94e831e05..189028881c9e 100644 --- a/src/SignalR/common/Shared/AcknowledgePipeV2.cs +++ b/src/SignalR/common/Shared/AckPipeReader.cs @@ -1,17 +1,15 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Buffers; -using System.Buffers.Text; using System.Diagnostics; using System.IO.Pipelines; -using System.Threading; using System.Threading.Tasks; +using System.Threading; +using System; #nullable enable -namespace PipelinesOverNetwork; +namespace Microsoft.AspNetCore.Http.Connections; // Wrapper around a PipeReader that adds an Ack position which replaces Consumed // This allows the underlying pipe to keep un-acked data in the pipe while still providing only new data to the reader @@ -88,7 +86,7 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami //} //else //{ - _inner.AdvanceTo(_ackPosition, examined); + _inner.AdvanceTo(_ackPosition, examined); //} if (_consumed.Equals(_ackPosition)) @@ -194,104 +192,3 @@ public override bool TryRead(out ReadResult result) throw new NotImplementedException(); } } - -// Wrapper around a PipeWriter that adds framing to writes -internal sealed class AckPipeWriter : PipeWriter -{ - public const int FrameSize = 24; - private readonly PipeWriter _inner; - internal long LastAck; - - Memory _frameHeader; - bool _shouldAdvanceFrameHeader; - private long _buffered; - - public AckPipeWriter(PipeWriter inner) - { - _inner = inner; - } - - public override void Advance(int bytes) - { - _buffered += bytes; - if (_shouldAdvanceFrameHeader) - { - bytes += FrameSize; - _shouldAdvanceFrameHeader = false; - } - _inner.Advance(bytes); - } - - public override void CancelPendingFlush() - { - _inner.CancelPendingFlush(); - } - - public override void Complete(Exception? exception = null) - { - _inner.Complete(exception); - } - - // TODO: We could reduce this to 16 bytes for binary transports and avoid the base64 encode/decode - // TODO: We could also reduce this to 1 + 12 (or 8) bytes occasionally if we add a flag for no new ack ID and avoid sending an ack - // X - 12 byte - size of payload as long and base64 encoded - // Y - 12 byte - number of acked bytes as long and base64 encoded - // Z - payload - // [ XXXX YYYY ZZZZ ] - public override ValueTask FlushAsync(CancellationToken cancellationToken = default) - { - Debug.Assert(_frameHeader.Length >= FrameSize); - - WriteFrame(_frameHeader.Span, _buffered, LastAck); - - _frameHeader = Memory.Empty; - _buffered = 0; - return _inner.FlushAsync(cancellationToken); - } - - public override Memory GetMemory(int sizeHint = 0) - { - var segment = _inner.GetMemory(Math.Max(FrameSize + 1, sizeHint)); - if (_frameHeader.IsEmpty || _buffered == 0) - { - Debug.Assert(segment.Length > FrameSize); - - _frameHeader = segment.Slice(0, FrameSize); - segment = segment.Slice(FrameSize); - _shouldAdvanceFrameHeader = true; - } - return segment; - } - - public override Span GetSpan(int sizeHint = 0) - { - return GetMemory(sizeHint).Span; - } - - public static void WriteFrame(Span header, long length, long ack) - { - Debug.Assert(header.Length >= FrameSize); - -#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - var res = BitConverter.TryWriteBytes(header, length); - Debug.Assert(res); - var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - res = BitConverter.TryWriteBytes(header.Slice(12), ack); - Debug.Assert(res); - status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); -#else - BitConverter.GetBytes(length).CopyTo(header); - var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - BitConverter.GetBytes(ack).CopyTo(header.Slice(12)); - status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); -#endif - } -} diff --git a/src/SignalR/common/Shared/AckPipeWriter.cs b/src/SignalR/common/Shared/AckPipeWriter.cs new file mode 100644 index 000000000000..a52f0573a3bf --- /dev/null +++ b/src/SignalR/common/Shared/AckPipeWriter.cs @@ -0,0 +1,115 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers.Text; +using System.Buffers; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Threading.Tasks; +using System.Threading; +using System; + +#nullable enable + +namespace Microsoft.AspNetCore.Http.Connections; + +// Wrapper around a PipeWriter that adds framing to writes +internal sealed class AckPipeWriter : PipeWriter +{ + public const int FrameSize = 24; + private readonly PipeWriter _inner; + internal long LastAck; + + Memory _frameHeader; + bool _shouldAdvanceFrameHeader; + private long _buffered; + + public AckPipeWriter(PipeWriter inner) + { + _inner = inner; + } + + public override void Advance(int bytes) + { + _buffered += bytes; + if (_shouldAdvanceFrameHeader) + { + bytes += FrameSize; + _shouldAdvanceFrameHeader = false; + } + _inner.Advance(bytes); + } + + public override void CancelPendingFlush() + { + _inner.CancelPendingFlush(); + } + + public override void Complete(Exception? exception = null) + { + _inner.Complete(exception); + } + + // TODO: We could reduce this to 16 bytes for binary transports and avoid the base64 encode/decode + // TODO: We could also reduce this to 1 + 12 (or 8) bytes occasionally if we add a flag for no new ack ID and avoid sending an ack + // X - 12 byte - size of payload as long and base64 encoded + // Y - 12 byte - number of acked bytes as long and base64 encoded + // Z - payload + // [ XXXX YYYY ZZZZ ] + public override ValueTask FlushAsync(CancellationToken cancellationToken = default) + { + Debug.Assert(_frameHeader.Length >= FrameSize); + + WriteFrame(_frameHeader.Span, _buffered, LastAck); + + _frameHeader = Memory.Empty; + _buffered = 0; + return _inner.FlushAsync(cancellationToken); + } + + public override Memory GetMemory(int sizeHint = 0) + { + var segment = _inner.GetMemory(Math.Max(FrameSize + 1, sizeHint)); + if (_frameHeader.IsEmpty || _buffered == 0) + { + Debug.Assert(segment.Length > FrameSize); + + _frameHeader = segment.Slice(0, FrameSize); + segment = segment.Slice(FrameSize); + _shouldAdvanceFrameHeader = true; + } + return segment; + } + + public override Span GetSpan(int sizeHint = 0) + { + return GetMemory(sizeHint).Span; + } + + public static void WriteFrame(Span header, long length, long ack) + { + Debug.Assert(header.Length >= FrameSize); + +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + var res = BitConverter.TryWriteBytes(header, length); + Debug.Assert(res); + var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + res = BitConverter.TryWriteBytes(header.Slice(12), ack); + Debug.Assert(res); + status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); +#else + BitConverter.GetBytes(length).CopyTo(header); + var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); + BitConverter.GetBytes(ack).CopyTo(header.Slice(12)); + status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); + Debug.Assert(status == OperationStatus.Done); + Debug.Assert(written == 12); +#endif + } +} diff --git a/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs b/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs deleted file mode 100644 index b6d72a7df146..000000000000 --- a/src/SignalR/common/Shared/AcknowledgePipe/DuplexPipe.cs +++ /dev/null @@ -1,93 +0,0 @@ -using System.IO.Pipelines; - -namespace PipelinesOverNetwork -{ - internal sealed class DuplexPipe : IDuplexPipe - { - public DuplexPipe(PipeReader reader, PipeWriter writer) - { - Input = reader; - Output = writer; - } - - public PipeReader Input { get; } - - public PipeWriter Output { get; } - - public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - var transportToApplication = new DuplexPipe(output.Reader, input.Writer); - var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } - - // This class exists to work around issues with value tuple on .NET Framework - public readonly struct DuplexPipePair - { - public IDuplexPipe Transport { get; } - public IDuplexPipe Application { get; } - - public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) - { - Transport = transport; - Application = application; - } - } - } - - internal sealed class AckDuplexPipe : IDuplexPipe - { - - public AckDuplexPipe(PipeReader reader, PipeWriter writer) - { - Input = reader; - Output = writer; - } - - public PipeReader Input { get; } - - public PipeWriter Output { get; } - - public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - // wire up both sides for testing - var ackWriterApp = new AckPipeWriter(output.Writer); - var ackReaderApp = new AckPipeReader(output.Reader); - var ackWriterClient = new AckPipeWriter(input.Writer); - var ackReaderClient = new AckPipeReader(input.Reader); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReaderApp); - var applicationReader = new ParseAckPipeReader(ackReaderApp, ackWriterClient, ackReaderClient); - var transportToApplication = new AckDuplexPipe(applicationReader, ackWriterClient); - var applicationToTransport = new AckDuplexPipe(transportReader, ackWriterApp); - - // Use for one side only, i.e. server - //var ackWriter = new AckPipeWriter(output.Writer); - //var ackReader = new AckPipeReader(output.Reader); - //var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); - //var transportToApplication = new DuplexPipe(ackReader, input.Writer); - //var applicationToTransport = new DuplexPipe(transportReader, ackWriter); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } - - // This class exists to work around issues with value tuple on .NET Framework - public readonly struct DuplexPipePair - { - public IDuplexPipe Transport { get; } - public IDuplexPipe Application { get; } - - public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) - { - Transport = transport; - Application = application; - } - } - } -} diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index 9e723777f605..d2dd82f163bf 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -11,7 +11,7 @@ #nullable enable -namespace PipelinesOverNetwork; +namespace Microsoft.AspNetCore.Http.Connections; // Read from "network" // Parse framing and slice the read so the application doesn't see the framing diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs index d0adfdc64488..8e1225369ae7 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -5,7 +5,7 @@ using System.Buffers.Text; using System.Diagnostics; using System.IO.Pipelines; -using PipelinesOverNetwork; +using Microsoft.AspNetCore.Http.Connections; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; diff --git a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index 9a08da67fabb..1f5288af77e8 100644 --- a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -7,7 +7,8 @@ - + + From 2d67269697333b1470ff10946adbde1aa90da335 Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 14 Apr 2023 14:47:10 -0700 Subject: [PATCH 08/25] rebase --- .../src/Internal/HttpConnectionContext.cs | 3 +++ .../src/Internal/HttpConnectionDispatcher.cs | 11 ++++------- .../Internal/Transports/WebSocketsServerTransport.cs | 5 ++++- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 9927481610c6..70c8a3f341bf 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Timeouts; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -407,6 +408,8 @@ internal bool TryActivatePersistentConnection( // Start the transport TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); + context.Features.Get()?.DisableTimeout(); + return true; } else diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 16060e692351..2f83004a346b 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -285,7 +285,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti { if (transport != HttpTransportType.LongPolling) { - await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false, HttpConnectionStopStatus.NormalClosure); } else { @@ -309,7 +309,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti // TODO: If acks aren't enabled we can close the connection immediately if (await connection.TransportTask!) { - await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); } else { @@ -327,15 +327,12 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } } - private async Task DoPersistentConnection(HttpConnectionContext connection, HttpContext context) + private async Task DoPersistentConnection(HttpConnectionContext connection) { - context.Features.Get()?.DisableTimeout(); - // Wait for any of them to end await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!); - await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); - } + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); } private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope) diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index c6d20d67b0fb..391317715e15 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -60,9 +60,12 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationTok public async Task ProcessSocketAsync(WebSocket socket) { + var ignoreFirstCancel = false; if (_application.Input is AckPipeReader reader) { _aborted = false; + // TODO: why is this needed on initial connection start, ideally should be in if condition below + ignoreFirstCancel = true; // TODO: check if the pipe was used previously? // Currently checked in Resend if (reader.Resend()) @@ -98,7 +101,7 @@ public async Task ProcessSocketAsync(WebSocket socket) } var receiving = StartReceiving(socket); - var sending = StartSending(socket, true); + var sending = StartSending(socket, ignoreFirstCancel); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending); From 84da68f6490fceaf9b6e774f6f7f214825c54143 Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 14 Apr 2023 16:50:47 -0700 Subject: [PATCH 09/25] cleanup --- .../test/FunctionalTests/ProxyStartup.cs | 104 ------------------ .../src/HttpConnection.cs | 2 + .../src/Internal/LongPollingTransport.cs | 26 +---- .../src/Internal/ServerSentEventsTransport.cs | 26 +---- .../src/Internal/HttpConnectionDispatcher.cs | 38 +------ .../Transports/WebSocketsServerTransport.cs | 29 ++--- src/SignalR/common/Shared/AckPipeReader.cs | 9 +- src/SignalR/samples/ClientSample/HubSample.cs | 3 +- src/SignalR/samples/SignalRSamples/Program.cs | 3 +- src/SignalR/samples/SignalRSamples/Startup.cs | 7 +- 10 files changed, 20 insertions(+), 227 deletions(-) delete mode 100644 src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs deleted file mode 100644 index 1150baa476ab..000000000000 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/ProxyStartup.cs +++ /dev/null @@ -1,104 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.IdentityModel.Tokens.Jwt; -using System.IO; -using System.Net.Http; -using System.Net.WebSockets; -using System.Security.Claims; -using Microsoft.AspNetCore.Authentication.JwtBearer; -using Microsoft.AspNetCore.Authentication.Negotiate; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.DataProtection; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Connections; -using Microsoft.AspNetCore.Routing; -using Microsoft.AspNetCore.WebSockets; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.IdentityModel.Tokens; -using Newtonsoft.Json; - -namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests; - -public class ProxyStartup -{ - private string ServerUrl; - - public void ConfigureServices(IServiceCollection services) - { - // Since tests run in parallel, it's possible multiple servers will startup and read files being written by another test - // Use a unique directory per server to avoid this collision - services.AddDataProtection() - .PersistKeysToFileSystem(Directory.CreateDirectory(Path.GetRandomFileName())); - - services.AddWebSockets(o => o.KeepAliveInterval = TimeSpan.Zero); - - services.AddRouting(); - } - - public void Configure(IApplicationBuilder app) - { - app.UseRouting(); - app.UseWebSockets(); - - app.Use(next => - { - return async context => - { - if (context.Request.Path.Value.EndsWith("/server", StringComparison.Ordinal)) - { - ServerUrl = context.Request.Query["url"]; - } - else if (context.Request.Path.Value.EndsWith("/drop", StringComparison.Ordinal)) - { - // TODO: drop connection - // for testing seamless reconnect - } - else - { - // TODO: forward to server - if (context.WebSockets.IsWebSocketRequest) - { - var uriBuilder = new UriBuilder(ServerUrl); - uriBuilder.Path = context.Request.Path; - uriBuilder.Scheme = context.Request.IsHttps ? "wss" : "ws"; - uriBuilder.Query = context.Request.QueryString.Value; - using var ws = await context.WebSockets.AcceptWebSocketAsync(); - using var forwardingWebsocket = new ClientWebSocket(); - await forwardingWebsocket.ConnectAsync(uriBuilder.Uri, new CancellationTokenSource(TimeSpan.FromSeconds(30)).Token); - var recvTask = Forward(ws, forwardingWebsocket); - var sendTask = Forward(forwardingWebsocket, ws); - - await Task.WhenAny(recvTask, sendTask); - } - else - { - var uriBuilder = new UriBuilder(ServerUrl); - uriBuilder.Path = context.Request.Path; - uriBuilder.Query = context.Request.QueryString.Value; - using var httpClient = new HttpClient(); - var request = new HttpRequestMessage(new HttpMethod(context.Request.Method), uriBuilder.ToString()); - request.Content = new StreamContent(context.Request.Body); - var resp = await httpClient.SendAsync(request); - - context.Response.StatusCode = (int)resp.StatusCode; - await resp.Content.CopyToAsync(context.Response.Body); - } - } - await next(context); - }; - }); - } - - private static async Task Forward(WebSocket ws, WebSocket forwardWebSocket) - { - var buffer = new byte[4096]; - while (forwardWebSocket.CloseStatus is null) - { - var res = await ws.ReceiveAsync(buffer, cancellationToken: default); - await forwardWebSocket.SendAsync(buffer.AsMemory(..res.Count), res.MessageType, res.EndOfMessage, cancellationToken: default); - } - } -} diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 9a40615c1c4e..18115c572434 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -397,6 +397,8 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel // The negotiation response gets cleared in the fallback scenario. if (negotiationResponse == null) { + // Temporary until other transports work + _httpConnectionOptions.UseAcks = transportType == HttpTransportType.WebSockets ? _httpConnectionOptions.UseAcks : false; negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken).ConfigureAwait(false); connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs index 04cb1a7c7982..66d2a02c015b 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs @@ -11,7 +11,6 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -60,35 +59,12 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio } // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - DuplexPipePair pair; - if (_useAck) - { - pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - } - else - { - pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - } + var pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); _transport = pair.Transport; _application = pair.Application; Running = ProcessAsync(url); - - static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - // Use for one side only, i.e. server - var ackWriterApp = new AckPipeWriter(output.Writer); - var ackReader = new AckPipeReader(output.Reader); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReader); - var transportToApplication = new DuplexPipe(ackReader, input.Writer); - var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } } private async Task ProcessAsync(Uri url) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs index e8583be8d516..d6758849df4f 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs @@ -13,7 +13,6 @@ using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -76,15 +75,7 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio } // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - DuplexPipePair pair; - if (_useAck) - { - pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - } - else - { - pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - } + var pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); _transport = pair.Transport; _application = pair.Application; @@ -95,21 +86,6 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // _application.Input.OnWriterCompleted((exception, state) => ((CancellationTokenSource)state).Cancel(), inputCts); Running = ProcessAsync(url, response); - - static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - // Use for one side only, i.e. server - var ackWriter = new AckPipeWriter(output.Writer); - var ackReader = new AckPipeReader(output.Reader); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); - var transportToApplication = new DuplexPipe(ackReader, input.Writer); - var applicationToTransport = new DuplexPipe(transportReader, ackWriter); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } } private async Task ProcessAsync(Uri url, HttpResponseMessage response) diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 2f83004a346b..e50bd09fee47 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -158,52 +158,20 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti await DoPersistentConnection(connection); } } - //else if (context.WebSockets.IsWebSocketRequest) - //{ - // // Connection can be established lazily - // var connection = await GetOrCreateConnectionAsync(context, options); - // if (connection == null) - // { - // // No such connection, GetOrCreateConnection already set the response status code - // return; - // } - - // if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.WebSockets, supportedTransports, logScope)) - // { - // // Bad connection state. It's already set the response status code. - // return; - // } - - // Log.EstablishedConnection(_logger); - - // // Allow the reads to be canceled - // connection.Cancellation = new CancellationTokenSource(); - - // var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); - - // await DoPersistentConnection(connectionDelegate, ws, context, connection); - //} else { - // GET /{path} maps to long polling + // GET /{path} maps to long polling or WebSockets + HttpConnectionContext? connection; var transport = HttpTransportType.LongPolling; if (context.WebSockets.IsWebSocketRequest) { transport = HttpTransportType.WebSockets; - } - else - { - AddNoCacheHeaders(context.Response); - } - - HttpConnectionContext? connection; - if (transport == HttpTransportType.WebSockets) - { connection = await GetOrCreateConnectionAsync(context, options); } else { + AddNoCacheHeaders(context.Response); // Connection must already exist connection = await GetConnectionAsync(context); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 391317715e15..21c2e804af58 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -75,7 +75,6 @@ public async Task ProcessSocketAsync(WebSocket socket) // 2. Send ack ID to client for last message we received from client before it disconnected // 3. Resume normal send/receive loops - // wait for first read? var buf = new byte[AckPipeWriter.FrameSize]; var res = await socket.ReceiveAsync(buf, _connection.Cancellation?.Token ?? default); Debug.Assert(res.Count == AckPipeWriter.FrameSize); @@ -175,7 +174,6 @@ private async Task StartReceiving(WebSocket socket) { while (!token.IsCancellationRequested) { - // Do a 0 byte read so that idle connections don't allocate a buffer when waiting for a read var result = await socket.ReceiveAsync(Memory.Empty, token); @@ -197,7 +195,6 @@ private async Task StartReceiving(WebSocket socket) } Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - //LogBytes(memory.Slice(0, receiveResult.Count), _logger); _application.Output.Advance(receiveResult.Count); @@ -258,6 +255,8 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) break; } + ignoreFirstCancel = false; + if (!buffer.IsEmpty) { try @@ -281,7 +280,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) catch (OperationCanceledException ex) when (ex.CancellationToken == _connection.SendingToken) { _closed = true; - // Log + // TODO: probably log break; } catch (Exception ex) @@ -297,16 +296,6 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) { break; } - else if (ignoreFirstCancel) - { - //var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary - // ? WebSocketMessageType.Binary - // : WebSocketMessageType.Text); - //var buf = new byte[AckPipeWriter.FrameSize]; - //AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)(_connection.Transport.Output)).lastAck); - //await socket.SendAsync(buffer, webSocketMessageType, _connection.SendingToken); - } - ignoreFirstCancel = false; } finally { @@ -335,15 +324,15 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) } } - //if (error is not null) - //{ - // _logger.LogError("Error in send {ex}.", error); - //} - if (_closed) { - _application.Input.Complete(); + _application.Input.Complete(error); } + // TODO + //else if (error is not null) + //{ + // _logger.LogError("Error in send {ex}.", error); + //} } } diff --git a/src/SignalR/common/Shared/AckPipeReader.cs b/src/SignalR/common/Shared/AckPipeReader.cs index 189028881c9e..ed9e5d402abe 100644 --- a/src/SignalR/common/Shared/AckPipeReader.cs +++ b/src/SignalR/common/Shared/AckPipeReader.cs @@ -79,15 +79,8 @@ public override void AdvanceTo(SequencePosition consumed) public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) { _consumed = consumed; - //if (_ackPosition.Equals(default)) - //{ - // Debug.Assert(false); - // _inner.AdvanceTo(consumed, examined); - //} - //else - //{ + // Consumed stays at the ack positions, we store the passed in consumed value for use in ReadAsync so we can give the user only new data _inner.AdvanceTo(_ackPosition, examined); - //} if (_consumed.Equals(_ackPosition)) { diff --git a/src/SignalR/samples/ClientSample/HubSample.cs b/src/SignalR/samples/ClientSample/HubSample.cs index c1b121fa4ef6..506a520b2238 100644 --- a/src/SignalR/samples/ClientSample/HubSample.cs +++ b/src/SignalR/samples/ClientSample/HubSample.cs @@ -34,7 +34,7 @@ public static async Task ExecuteAsync(string baseUrl) var connectionBuilder = new HubConnectionBuilder() .ConfigureLogging(logging => { - //logging.AddConsole(); + logging.AddConsole(); }); connectionBuilder.Services.Configure(options => @@ -55,7 +55,6 @@ public static async Task ExecuteAsync(string baseUrl) using var closedTokenSource = new CancellationTokenSource(); var connection = connectionBuilder.Build(); - connection.ServerTimeout = TimeSpan.FromSeconds(15); try { diff --git a/src/SignalR/samples/SignalRSamples/Program.cs b/src/SignalR/samples/SignalRSamples/Program.cs index 757e269b9e99..c610b486315f 100644 --- a/src/SignalR/samples/SignalRSamples/Program.cs +++ b/src/SignalR/samples/SignalRSamples/Program.cs @@ -25,8 +25,7 @@ public static Task Main(string[] args) { factory.AddConfiguration(c.Configuration.GetSection("Logging")); factory.AddConsole(); - factory.SetMinimumLevel(LogLevel.Trace); - //factory.SetMinimumLevel(LogLevel.Debug); + factory.SetMinimumLevel(LogLevel.Debug); }) .UseKestrel(options => { diff --git a/src/SignalR/samples/SignalRSamples/Startup.cs b/src/SignalR/samples/SignalRSamples/Startup.cs index 5c42ffdb7293..5a3d67e481c3 100644 --- a/src/SignalR/samples/SignalRSamples/Startup.cs +++ b/src/SignalR/samples/SignalRSamples/Startup.cs @@ -18,12 +18,7 @@ public void ConfigureServices(IServiceCollection services) { services.AddConnections(); - services.AddSignalR(o => - { - o.MaximumParallelInvocationsPerClient = 10; - o.ClientTimeoutInterval = TimeSpan.FromSeconds(100); - o.KeepAliveInterval = TimeSpan.FromSeconds(5); - }) + services.AddSignalR() .AddMessagePackProtocol(); //.AddStackExchangeRedis(); } From e561e383802bfa83aef7d3f02240e1926f24fcfe Mon Sep 17 00:00:00 2001 From: Brennan Date: Tue, 18 Apr 2023 17:49:22 -0700 Subject: [PATCH 10/25] spec --- .../FunctionalTests/HubConnectionTests.cs | 12 +-- .../src/HttpConnectionOptions.cs | 7 +- src/SignalR/docs/specs/TransportProtocols.md | 83 +++++++++++++++++++ 3 files changed, 92 insertions(+), 10 deletions(-) diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index c113d7b9a2ed..ce464f5b179c 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -2539,10 +2539,8 @@ public async Task ServerSentEventsWorksWithHttp2OnlyEndpoint() } [Fact] - //[Theory] - //[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] [Repeat(500)] - public async Task CanReconnectAndSendMessageWhileDisconnected(/*string protocolName, HttpTransportType transportType, string path*/) + public async Task CanReconnectAndSendMessageWhileDisconnected() { var protocol = HubProtocols["json"]; await using (var server = await StartServer()) @@ -2601,10 +2599,8 @@ public async Task CanReconnectAndSendMessageWhileDisconnected(/*string protocolN } [Fact] - //[Theory] - //[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] [Repeat(500)] - public async Task CanReconnectAndSendMessageOnceConnected(/*string protocolName, HttpTransportType transportType, string path*/) + public async Task CanReconnectAndSendMessageOnceConnected() { var protocol = HubProtocols["json"]; await using (var server = await StartServer()) @@ -2673,10 +2669,8 @@ public async Task CanReconnectAndSendMessageOnceConnected(/*string protocolName, } [Fact] - //[Theory] - //[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] [Repeat(500)] - public async Task ServerAbortsConnectionNoReconnectAttempted(/*string protocolName, HttpTransportType transportType, string path*/) + public async Task ServerAbortsConnectionNoReconnectAttempted() { var protocol = HubProtocols["json"]; await using (var server = await StartServer()) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs index d63a60c136e4..c7a2ae78ed07 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs @@ -276,8 +276,13 @@ public Action? WebSocketConfiguration } /// - /// TODO + /// Setting to enable acking bytes sent between client and server, this allows reconnecting that preserves messages sent while disconnected. + /// Also preserves the when the reconnect is successful. /// + /// + /// Only works with WebSockets transport currently. + /// API likely to change in future previews. + /// public bool UseAcks { get; set; } private static void ThrowIfUnsupportedPlatform() diff --git a/src/SignalR/docs/specs/TransportProtocols.md b/src/SignalR/docs/specs/TransportProtocols.md index a4c10f4eadfa..a4478220e57a 100644 --- a/src/SignalR/docs/specs/TransportProtocols.md +++ b/src/SignalR/docs/specs/TransportProtocols.md @@ -20,12 +20,20 @@ Throughout this document, the term `[endpoint-base]` is used to refer to the rou The `POST [endpoint-base]/negotiate` request is used to establish a connection between the client and the server. +*negotiateVersion:* + In the POST request the client sends a query string parameter with the key "negotiateVersion" and the value as the negotiate protocol version it would like to use. If the query string is omitted, the server treats the version as zero. The server will include a "negotiateVersion" property in the json response that says which version it will be using. The version is chosen as described below: * If the servers minimum supported protocol version is greater than the version requested by the client it will send an error response and close the connection * If the server supports the request version it will respond with the requested version * If the requested version is greater than the servers largest supported version the server will respond with its largest supported version The client may close the connection if the "negotiateVersion" in the response is not acceptable. +*useAck:* + +In the POST request the client may include a query string parameter with the key "useAck" and the value of "true". If this is included the server will decide if it supports/allows the [ack protocol](#ack-protocol) described below, and return "useAck": "true" as a json property in the negotiate response if it will use the ack protocol. If true, the client must use the ack protocol when sending/receiving otherwise the connection will be terminated. Similarly, the server must use the ack protocol when sending/receiving. If false, the client must not use the ack protocol and will be terminated if it does. If the "useAck" property is missing from the negotiate response this also implies false, so the ack protocol should not be used. + +----------- + The content type of the response is `application/json` and is a JSON payload containing properties to assist the client in establishing a persistent connection. Extra JSON properties that the client does not know about should be ignored. This allows for future additions without breaking older clients. ### Version 1 @@ -197,3 +205,78 @@ When data is available, the server responds with a body in one of the two format If the `id` parameter is missing, a `400 Bad Request` response is returned. If there is no connection with the ID specified in `id`, a `404 Not Found` response is returned. When the client has finished with the connection, it can issue a `DELETE` request to `[endpoint-base]` (with the `id` in the query string) to gracefully terminate the connection. The server will complete the latest poll with `204` to indicate that it has shut down. + +## Ack Protocol + +The ack protocol primarily consists of writing and reading framing around the data being sent and received. +All sends need to start with a 24 byte frame. The frame is 2 12 byte base64 encoded values. The first base64 value when decoded is the length of the payload being sent (minus the framing) as an int64 value. The second base64 value when decoded is the ack ID as an int64 of how many bytes have been received from the other side so far. + +The second part of the protocol is for when the transport ungracefully reconnects and uses the Ack IDs to get any data that might have been missed during the disconnect window. This will be described after showing the framing. + +### Framing + +Consider the following example: + +0x41 0x67 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d 0x48 0x51 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d 0x48 0x69 + +This is a 26 byte message, the first 24 bytes are the framing, which we'll split into two 12 byte sections and the 2 remaining bytes +0x41 0x67 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d - Base64 represention as bytes +AgAAAAAAAAA= - Base64 representation in ASCII +2 0 0 0 0 0 0 0 0 0 0 0 - Base64 decoded, int64 value of 2, representing a 2 length payload after the framing + +0x48 0x51 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d - Base64 represention as bytes +HQAAAAAAAAA= - Base64 representation in ASCII +29 0 0 0 0 0 0 0 0 0 0 0 - Base64 decoded, int64 value of 29, representing an ack id of 29 bytes received from the endpoint so far + +0x48 0x69 +Hi + +From now on we'll use `[ , ]` annotation to represent the framing, with an implicit payload attached to it. + +To explain the Ack IDs we'll use the following example which is sending between a client and server, C and S respectively: + +``` +C->S: [ 5, 0 ] +S->C: [ 10, 29 ] +S->C: [ 13, 29 ] +C->S: [ 22, 71 ] +S->C: [ 1, 75 ] +``` + +The first send will send an Ack ID of 0 because the client hasn't received any data yet, so there is nothing to ack. When the server sends after it's received a message from the client it will send an Ack ID of the payload length (5) + the frame length (24), so 29. In this example we also send another message which won't have an updated Ack ID, because nothing new was received, so we send the previous value. The client in its next send adds all the received messages together to get the Ack ID to send to the server, 24 + 10 from the first message received, 24 + 13 from the second message received, for a total of 71. And then finally, the server adds its previously sent Ack ID of 29 with the message(s) received since its last send (24 + 22), for a total of 75 for the Ack ID it sends to the client. + +### Reconnect + +The second part of the protocol is what makes use of the Ack IDs. + +If a transport ungracefully disconnects the client can attempt to reconnect using the same `id` it was using before. The server is free to reject any reconnect attempts, but generally should allow a few seconds grace period. + +On a successful reconnect the client must send an Ack ID with a 0 length payload to the server indicating the last message it received before disconnecting. The client then waits for a message from the server that will contain the last Ack ID the server received before the disconnection, as well as a 0 length payload. This message **does not** increment the Ack ID tracking. The Ack ID received from the server will be used to send any missed messages from the client to the server. The normal send/receive loops can now start and if there is any unacked data on the client side the send loop should immediately send the missed data (framing and all). + +On a successful reconnect the server must wait for the client to send the last Ack ID it received before disconnecting. This message **does not** increment the Ack ID tracking. The Ack ID received from the client will be used to send any missed messages from the server to the client. The server will then send the last Ack ID it received before the disconnect occurred as well as a 0 length payload. The normal send/receive loops can now start and if there is any unacked data on the server side the send loop should immediately send the missed data (framing and all). + +The following example will send a few messages between client and server before having an ungraceful disconnect to show the reconnect flow: + +``` +C->S: [ 10, 0 ] +S->C: [ 1, 34 ] +C->S: [ 11, 25 ] +// Ungraceful disconnect +C->S: [ 0, 25 ] +S->C: [ 0, 34 ] +// normal send/receive loops for both sides are now started +C->S: [ 11, 25 ] // resend 11 byte payload that server didn't get before disconnect occurred +``` + +Another example that is the same as the last example except that the server did receive the clients last send before the disconnect: + +``` +C->S: [ 10, 0 ] +S->C: [ 1, 34 ] +C->S: [ 11, 25 ] +// Ungraceful disconnect +C->S: [ 0, 25 ] +S->C: [ 0, 69 ] +// normal send/receive loops for both sides are now started +// 11 bytes from C->S not resent because server did get it before the disconnect, as can be seen by the new Ack ID +``` \ No newline at end of file From b389e2c2dc579856e71c99936021b4c58f7c7569 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 24 Apr 2023 13:11:16 -0700 Subject: [PATCH 11/25] some fb --- .../src/Internal/WebSocketsTransport.cs | 49 ++++++++++++++----- .../Transports/WebSocketsServerTransport.cs | 26 ++++++++-- src/SignalR/common/Shared/AckPipeWriter.cs | 35 +++++-------- .../common/Shared/ParseAckPipeReader.cs | 45 +++++++---------- src/SignalR/docs/specs/TransportProtocols.md | 6 +-- .../test/DefaultTransportFactoryTests.cs | 10 ++-- 6 files changed, 97 insertions(+), 74 deletions(-) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 3872e13c51cc..c1bb7c98adb9 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; using System.Diagnostics; +using System.IO; using System.IO.Pipelines; using System.Net; using System.Net.Http; @@ -37,7 +38,9 @@ internal sealed partial class WebSocketsTransport : ITransport private readonly bool _useAck; private IDuplexPipe? _transport; - private bool _closed; + // Used for reconnect (when enabled) to determine if the close was ungraceful or not, reconnect only happens on ungraceful disconnect + // The assumption is that a graceful close was triggered purposefully by either the client or server and a reconnect shouldn't occur + private bool _gracefulClose; internal Task Running { get; private set; } = Task.CompletedTask; @@ -318,14 +321,34 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // 3. Resume normal send/receive loops ignoreFirstCanceled = true; - var buf = new byte[AckPipeWriter.FrameSize]; + var buf = new byte[AckPipeWriter.FrameHeaderSize]; AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)_transport.Output).LastAck); - await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _webSocketMessageType, true, _stopCts.Token).ConfigureAwait(false); + await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameHeaderSize), _webSocketMessageType, true, _stopCts.Token).ConfigureAwait(false); Array.Clear(buf, 0, buf.Length); - var receiveResult = await _webSocket.ReceiveAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameSize), _stopCts.Token).ConfigureAwait(false); // server sends 0 length, but with latest ack, so there shouldn't be more than a frame of data sent - Debug.Assert(receiveResult.Count == AckPipeWriter.FrameSize); + var readLength = 0; + WebSocketReceiveResult? receiveResult; + do + { + receiveResult = await _webSocket.ReceiveAsync(new ArraySegment(buf, readLength, AckPipeWriter.FrameHeaderSize - readLength), _stopCts.Token).ConfigureAwait(false); + readLength += receiveResult.Count; + } while (readLength < AckPipeWriter.FrameHeaderSize && !receiveResult.EndOfMessage); + + if (readLength != AckPipeWriter.FrameHeaderSize) + { + _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received less data than expected.")); + _application.Input.Complete(); + return; + } + + if (!receiveResult.EndOfMessage) + { + _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received more data than expected.")); + _application.Input.Complete(); + return; + } + // Parsing ack id and updating reader here avoids issue where we send to server before receive loop runs, which is what normally updates ack // This avoids resending data that was already acked var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(buf), reader); @@ -407,14 +430,14 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirs socket.Abort(); // Cancel any pending flush so that we can quit - if (_closed) + if (_gracefulClose) { _application.Output.CancelPendingFlush(); } } } - if (_useAck && !_closed) + if (_useAck && !_gracefulClose) { await StartAsync(url, _webSocketMessageType == WebSocketMessageType.Binary ? TransferFormat.Binary : TransferFormat.Text, default).ConfigureAwait(false); } @@ -434,7 +457,7 @@ private async Task StartReceiving(WebSocket socket) if (result.MessageType == WebSocketMessageType.Close) { - _closed = true; + _gracefulClose = true; Log.WebSocketClosed(_logger, socket.CloseStatus); if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure) @@ -461,7 +484,7 @@ private async Task StartReceiving(WebSocket socket) // Need to check again for netstandard2.1 because a close can happen between a 0-byte read and the actual read if (receiveResult.MessageType == WebSocketMessageType.Close) { - _closed = true; + _gracefulClose = true; Log.WebSocketClosed(_logger, socket.CloseStatus); if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure) @@ -494,7 +517,7 @@ private async Task StartReceiving(WebSocket socket) { if (!_aborted) { - if (_closed) + if (_gracefulClose) { _application.Output.Complete(ex); } @@ -508,7 +531,7 @@ private async Task StartReceiving(WebSocket socket) finally { // We're done writing - if (_closed) + if (_gracefulClose) { _application.Output.Complete(); } @@ -605,7 +628,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) } } - if (_closed) + if (_gracefulClose) { _application.Input.Complete(error); } @@ -639,7 +662,7 @@ private static Uri ResolveWebSocketsUrl(Uri url) public async Task StopAsync() { - _closed = true; + _gracefulClose = true; Log.TransportStopping(_logger); if (_application == null) diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 21c2e804af58..202ec9df9697 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -75,9 +75,29 @@ public async Task ProcessSocketAsync(WebSocket socket) // 2. Send ack ID to client for last message we received from client before it disconnected // 3. Resume normal send/receive loops - var buf = new byte[AckPipeWriter.FrameSize]; - var res = await socket.ReceiveAsync(buf, _connection.Cancellation?.Token ?? default); - Debug.Assert(res.Count == AckPipeWriter.FrameSize); + var buf = new byte[AckPipeWriter.FrameHeaderSize]; + WebSocketReceiveResult? res; + var readLength = 0; + do + { + res = await socket.ReceiveAsync(new ArraySegment(buf, readLength, AckPipeWriter.FrameHeaderSize - readLength), _connection.Cancellation?.Token ?? default); + readLength += res.Count; + } while (readLength < AckPipeWriter.FrameHeaderSize && !res.EndOfMessage); + + if (readLength != AckPipeWriter.FrameHeaderSize) + { + _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received less data than expected.")); + _application.Input.Complete(); + return; + } + + if (!res.EndOfMessage) + { + _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received more data than expected.")); + _application.Input.Complete(); + return; + } + // Needed so that the readers ack position gets updated and we don't re-send messages to client // Normally this would be done by the HubConnectionHandler loop, but that requires a new message to be read // so we instead make sure it's updated immediately here diff --git a/src/SignalR/common/Shared/AckPipeWriter.cs b/src/SignalR/common/Shared/AckPipeWriter.cs index a52f0573a3bf..88350fdf8754 100644 --- a/src/SignalR/common/Shared/AckPipeWriter.cs +++ b/src/SignalR/common/Shared/AckPipeWriter.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using System.Threading; using System; +using System.Buffers.Binary; #nullable enable @@ -16,7 +17,7 @@ namespace Microsoft.AspNetCore.Http.Connections; // Wrapper around a PipeWriter that adds framing to writes internal sealed class AckPipeWriter : PipeWriter { - public const int FrameSize = 24; + public const int FrameHeaderSize = 24; private readonly PipeWriter _inner; internal long LastAck; @@ -34,7 +35,7 @@ public override void Advance(int bytes) _buffered += bytes; if (_shouldAdvanceFrameHeader) { - bytes += FrameSize; + bytes += FrameHeaderSize; _shouldAdvanceFrameHeader = false; } _inner.Advance(bytes); @@ -58,7 +59,7 @@ public override void Complete(Exception? exception = null) // [ XXXX YYYY ZZZZ ] public override ValueTask FlushAsync(CancellationToken cancellationToken = default) { - Debug.Assert(_frameHeader.Length >= FrameSize); + Debug.Assert(_frameHeader.Length >= FrameHeaderSize); WriteFrame(_frameHeader.Span, _buffered, LastAck); @@ -69,13 +70,13 @@ public override ValueTask FlushAsync(CancellationToken cancellation public override Memory GetMemory(int sizeHint = 0) { - var segment = _inner.GetMemory(Math.Max(FrameSize + 1, sizeHint)); + var segment = _inner.GetMemory(Math.Max(FrameHeaderSize + 1, sizeHint)); if (_frameHeader.IsEmpty || _buffered == 0) { - Debug.Assert(segment.Length > FrameSize); + Debug.Assert(segment.Length > FrameHeaderSize); - _frameHeader = segment.Slice(0, FrameSize); - segment = segment.Slice(FrameSize); + _frameHeader = segment.Slice(0, FrameHeaderSize); + segment = segment.Slice(FrameHeaderSize); _shouldAdvanceFrameHeader = true; } return segment; @@ -88,28 +89,16 @@ public override Span GetSpan(int sizeHint = 0) public static void WriteFrame(Span header, long length, long ack) { - Debug.Assert(header.Length >= FrameSize); + Debug.Assert(header.Length >= FrameHeaderSize); -#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - var res = BitConverter.TryWriteBytes(header, length); - Debug.Assert(res); + BinaryPrimitives.WriteInt64LittleEndian(header, length); var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); Debug.Assert(status == OperationStatus.Done); Debug.Assert(written == 12); - res = BitConverter.TryWriteBytes(header.Slice(12), ack); - Debug.Assert(res); - status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); -#else - BitConverter.GetBytes(length).CopyTo(header); - var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - BitConverter.GetBytes(ack).CopyTo(header.Slice(12)); + + BinaryPrimitives.WriteInt64LittleEndian(header.Slice(12), ack); status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); Debug.Assert(status == OperationStatus.Done); Debug.Assert(written == 12); -#endif } } diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs index d2dd82f163bf..390bee81f5b1 100644 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ b/src/SignalR/common/Shared/ParseAckPipeReader.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Buffers.Binary; using System.Buffers.Text; using System.Diagnostics; using System.IO.Pipelines; @@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Http.Connections; // Notify application pipe of ack id provided by other side of the network internal sealed class ParseAckPipeReader : PipeReader { - private const int FrameSize = 24; + private const int FrameHeaderSize = 24; private readonly PipeReader _inner; private readonly AckPipeWriter _ackPipeWriter; private readonly AckPipeReader _ackPipeReader; @@ -79,9 +80,9 @@ public override async ValueTask ReadAsync(CancellationToken cancella if (res.IsCompleted || res.IsCanceled) { // TODO: figure out behavior - if (res.Buffer.Length >= FrameSize) + if (res.Buffer.Length >= FrameHeaderSize) { - res = new(res.Buffer.Slice(FrameSize), res.IsCanceled, res.IsCompleted); + res = new(res.Buffer.Slice(FrameHeaderSize), res.IsCanceled, res.IsCompleted); } return res; } @@ -90,23 +91,23 @@ public override async ValueTask ReadAsync(CancellationToken cancella if (_remaining == 0) { // TODO: didn't get 24 bytes - var frame = buffer.Slice(0, FrameSize); + var frame = buffer.Slice(0, FrameHeaderSize); var len = ParseFrame(frame, _ackPipeReader); _totalBytes += len; _remaining = len; // if the buffer doesn't have enough data we need to update how much we're slicing - if (len > buffer.Length - FrameSize) + if (len > buffer.Length - FrameHeaderSize) { - len = buffer.Length - FrameSize; + len = buffer.Length - FrameHeaderSize; } - buffer = buffer.Slice(FrameSize, len); + buffer = buffer.Slice(FrameHeaderSize, len); _currentRead = buffer; // 0 length means it was part of the reconnect handshake and not sent over the pipe, ignore it for acking purposes // TODO: check if 0 byte writes are possible in ConnectionHandlers and possibly handle them differently - _ackPipeWriter.LastAck += buffer.Length == 0 ? 0 : buffer.Length + FrameSize; + _ackPipeWriter.LastAck += buffer.Length == 0 ? 0 : buffer.Length + FrameHeaderSize; } else { @@ -141,37 +142,26 @@ public override async ValueTask ReadAsync(CancellationToken cancella public static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) { - Debug.Assert(frame.Length >= FrameSize); - frame = frame.Slice(0, FrameSize); + Debug.Assert(frame.Length >= FrameHeaderSize); + frame = frame.Slice(0, FrameHeaderSize); long len; long ackId; // TODO: check perf of single Span check vs Stackalloc - Span buffer = stackalloc byte[FrameSize]; + Span buffer = stackalloc byte[FrameHeaderSize]; frame.CopyTo(buffer); - var status = Base64.DecodeFromUtf8InPlace(buffer.Slice(0, FrameSize / 2), out var written); + var status = Base64.DecodeFromUtf8InPlace(buffer.Slice(0, FrameHeaderSize / 2), out var written); Debug.Assert(status == OperationStatus.Done); Debug.Assert(written == 8); -#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - len = BitConverter.ToInt64(buffer); -#else - var longBuf = new byte[8]; - buffer.Slice(0, 8).CopyTo(longBuf); - len = BitConverter.ToInt64(longBuf, 0); -#endif + len = BinaryPrimitives.ReadInt64LittleEndian(buffer); - status = Base64.DecodeFromUtf8InPlace(buffer.Slice(FrameSize / 2), out written); + var ackFrame = buffer.Slice(FrameHeaderSize / 2); + status = Base64.DecodeFromUtf8InPlace(ackFrame, out written); Debug.Assert(status == OperationStatus.Done); Debug.Assert(written == 8); - -#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER - ackId = BitConverter.ToInt64(buffer.Slice(FrameSize / 2)); -#else - buffer.Slice(12, 8).CopyTo(longBuf); - ackId = BitConverter.ToInt64(longBuf, 0); -#endif + ackId = BinaryPrimitives.ReadInt64LittleEndian(ackFrame); // Update ack id provided by other side, so the underlying pipe can release buffered memory ackPipeReader.Ack(ackId); @@ -180,6 +170,7 @@ public static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPip public override bool TryRead(out ReadResult result) { + // TODO: Not needed for SignalR, but could be called in ConnectionHandler layer of user code throw new NotImplementedException(); } } diff --git a/src/SignalR/docs/specs/TransportProtocols.md b/src/SignalR/docs/specs/TransportProtocols.md index a4478220e57a..13bf6d03f530 100644 --- a/src/SignalR/docs/specs/TransportProtocols.md +++ b/src/SignalR/docs/specs/TransportProtocols.md @@ -209,7 +209,7 @@ When the client has finished with the connection, it can issue a `DELETE` reques ## Ack Protocol The ack protocol primarily consists of writing and reading framing around the data being sent and received. -All sends need to start with a 24 byte frame. The frame is 2 12 byte base64 encoded values. The first base64 value when decoded is the length of the payload being sent (minus the framing) as an int64 value. The second base64 value when decoded is the ack ID as an int64 of how many bytes have been received from the other side so far. +All sends need to start with a 24 byte frame. The frame consists of 2 64-bit little-endian values, both base-64 encoded (preserving padding) for a total of 12 bytes. The first base-64 value when decoded is the length of the payload being sent (minus the framing) as an int64 value. The second base-64 value when decoded is the ack ID as an int64 of how many bytes have been received from the other side so far. The second part of the protocol is for when the transport ungracefully reconnects and uses the Ack IDs to get any data that might have been missed during the disconnect window. This will be described after showing the framing. @@ -220,11 +220,11 @@ Consider the following example: 0x41 0x67 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d 0x48 0x51 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d 0x48 0x69 This is a 26 byte message, the first 24 bytes are the framing, which we'll split into two 12 byte sections and the 2 remaining bytes -0x41 0x67 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d - Base64 represention as bytes +(hex) 41 67 41 41 41 41 41 41 41 41 41 3d - Base64 represention as bytes AgAAAAAAAAA= - Base64 representation in ASCII 2 0 0 0 0 0 0 0 0 0 0 0 - Base64 decoded, int64 value of 2, representing a 2 length payload after the framing -0x48 0x51 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d - Base64 represention as bytes +(hex) 48 51 41 41 41 41 41 41 41 41 41 3d - Base64 represention as bytes HQAAAAAAAAA= - Base64 representation in ASCII 29 0 0 0 0 0 0 0 0 0 0 0 - Base64 decoded, int64 value of 29, representing an ack id of 29 bytes received from the endpoint so far diff --git a/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs b/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs index d553e88db86c..3f954b0e37a2 100644 --- a/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs +++ b/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs @@ -53,7 +53,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(HttpTran { var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, - transportFactory.CreateTransport(AllTransportTypes, true)); + transportFactory.CreateTransport(AllTransportTypes, useAck: true)); } [Theory] @@ -66,7 +66,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(Http var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( - () => transportFactory.CreateTransport(~requestedTransport, true)); + () => transportFactory.CreateTransport(~requestedTransport, useAck: true)); Assert.Equal("No requested transports available on the server.", ex.Message); } @@ -77,7 +77,7 @@ public void DefaultTransportFactoryCreatesWebSocketsTransportIfAvailable() { Assert.IsType( new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null) - .CreateTransport(AllTransportTypes, true)); + .CreateTransport(AllTransportTypes, useAck: true)); } [Theory] @@ -90,7 +90,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable_Win7(Htt { var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, - transportFactory.CreateTransport(AllTransportTypes, true)); + transportFactory.CreateTransport(AllTransportTypes, useAck: true)); } } @@ -103,7 +103,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport_Win7 var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( - () => transportFactory.CreateTransport(AllTransportTypes, true)); + () => transportFactory.CreateTransport(AllTransportTypes, useAck: true)); Assert.Equal("No requested transports available on the server.", ex.Message); } From d1eccc17d420a194edf1f476bf9f2e1918f1f227 Mon Sep 17 00:00:00 2001 From: Brennan Date: Mon, 24 Apr 2023 16:04:59 -0700 Subject: [PATCH 12/25] backpressure and concrete pipe --- .../src/Internal/WebSocketsTransport.cs | 4 +- .../src/Internal/HttpConnectionManager.cs | 4 +- src/SignalR/common/Shared/AckPipeReader.cs | 7 ++- src/SignalR/common/Shared/AckPipeWriter.cs | 7 ++- .../test/Internal/Protocol/AckPipeTests.cs | 44 +++++++++++++++---- src/SignalR/docs/specs/TransportProtocols.md | 2 +- 6 files changed, 51 insertions(+), 17 deletions(-) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index c1bb7c98adb9..0078f421c263 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -374,8 +374,8 @@ static DuplexPipePair CreateAckConnectionPair(PipeOptions inputOptions, PipeOpti var output = new Pipe(outputOptions); // Use for one side only, i.e. server - var ackWriter = new AckPipeWriter(output.Writer); - var ackReader = new AckPipeReader(output.Reader); + var ackWriter = new AckPipeWriter(output); + var ackReader = new AckPipeReader(output); var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); var transportToApplication = new DuplexPipe(ackReader, input.Writer); var applicationToTransport = new DuplexPipe(transportReader, ackWriter); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index d0bd03d47c74..742e990f1f7e 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -104,8 +104,8 @@ static DuplexPipePair CreateAckConnectionPair(PipeOptions inputOptions, PipeOpti var output = new Pipe(outputOptions); // Use for one side only, i.e. server - var ackWriterApp = new AckPipeWriter(output.Writer); - var ackReader = new AckPipeReader(output.Reader); + var ackWriterApp = new AckPipeWriter(output); + var ackReader = new AckPipeReader(output); var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReader); var transportToApplication = new DuplexPipe(ackReader, input.Writer); var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); diff --git a/src/SignalR/common/Shared/AckPipeReader.cs b/src/SignalR/common/Shared/AckPipeReader.cs index ed9e5d402abe..081aef7a3aa0 100644 --- a/src/SignalR/common/Shared/AckPipeReader.cs +++ b/src/SignalR/common/Shared/AckPipeReader.cs @@ -25,9 +25,12 @@ internal sealed class AckPipeReader : PipeReader private long _totalWritten; private bool _resend; - public AckPipeReader(PipeReader inner) + // Accept Pipe instead of PipeReader because we don't want custom pipe implementations to be used with this type + // and Pipe is sealed so a custom one can't be provided + // We rely on undefined implementation details of the default Pipe + public AckPipeReader(Pipe innerPipe) { - _inner = inner; + _inner = innerPipe.Reader; } // Update the ack position. This number includes the framing size. diff --git a/src/SignalR/common/Shared/AckPipeWriter.cs b/src/SignalR/common/Shared/AckPipeWriter.cs index 88350fdf8754..6dff353ebffe 100644 --- a/src/SignalR/common/Shared/AckPipeWriter.cs +++ b/src/SignalR/common/Shared/AckPipeWriter.cs @@ -25,9 +25,12 @@ internal sealed class AckPipeWriter : PipeWriter bool _shouldAdvanceFrameHeader; private long _buffered; - public AckPipeWriter(PipeWriter inner) + // Accept Pipe instead of PipeWriter because we don't want custom pipe implementations to be used with this type + // and Pipe is sealed so a custom one can't be provided + // We rely on undefined implementation details of the default Pipe + public AckPipeWriter(Pipe innerPipe) { - _inner = inner; + _inner = innerPipe.Writer; } public override void Advance(int bytes) diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs index 8e1225369ae7..0b99495c9b47 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.IO.Pipelines; using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Testing; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; @@ -679,14 +680,41 @@ public async Task TriggerResendWhenPartialFrameAcked() Assert.False(res.IsCompleted); } + [Fact] + public async Task BackpressureIsAppliedInBothDirections() + { + var duplexPipe = CreateClient(inputOptions: new PipeOptions(pauseWriterThreshold: 10, resumeWriterThreshold: 5, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline), + outputOptions: new PipeOptions(pauseWriterThreshold: 10, resumeWriterThreshold: 5, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline)); + + var buffer = new byte[FrameSize + 1]; + WriteFrame(buffer, 1, 0); + var writeTask = duplexPipe.Application.Output.WriteAsync(buffer); + // Shouldn't complete until the reader reads due to pauseWriterThreshold being 10 and we wrote 25 + Assert.False(writeTask.IsCompleted); + + var res = await duplexPipe.Transport.Input.ReadAsync(); + Assert.Equal(1, res.Buffer.Length); + duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); + await writeTask.DefaultTimeout(); + + writeTask = duplexPipe.Transport.Output.WriteAsync(new byte[2] { 4, 5 }); + // Shouldn't complete until the reader reads due to pauseWriterThreshold being 10 and we wrote 26 + Assert.False(writeTask.IsCompleted); + + res = await duplexPipe.Application.Input.ReadAsync(); + Assert.Equal(26, res.Buffer.Length); + duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); + await writeTask.DefaultTimeout(); + } + internal static DuplexPipePair CreateClient(PipeOptions inputOptions = default, PipeOptions outputOptions = default) { var input = new Pipe(inputOptions ?? new()); var output = new Pipe(outputOptions ?? new()); // Use for one side only, this is client side - var ackWriter = new AckPipeWriter(output.Writer); - var ackReader = new AckPipeReader(output.Reader); + var ackWriter = new AckPipeWriter(output); + var ackReader = new AckPipeReader(output); var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); var transportToApplication = new DuplexPipe(ackReader, input.Writer); var applicationToTransport = new DuplexPipe(transportReader, ackWriter); @@ -703,8 +731,8 @@ internal static DuplexPipePair CreateServer(PipeOptions inputOptions = default, var output = new Pipe(outputOptions ?? new()); // Use for one side only, this is server side - var ackWriter = new AckPipeWriter(output.Writer); - var ackReader = new AckPipeReader(output.Reader); + var ackWriter = new AckPipeWriter(output); + var ackReader = new AckPipeReader(output); var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); var transportToApplication = new DuplexPipe(ackReader, input.Writer); var applicationToTransport = new DuplexPipe(transportReader, ackWriter); @@ -763,10 +791,10 @@ internal static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, Pi var output = new Pipe(outputOptions); // wire up both sides for testing - var ackWriterApp = new AckPipeWriter(output.Writer); - var ackReaderApp = new AckPipeReader(output.Reader); - var ackWriterClient = new AckPipeWriter(input.Writer); - var ackReaderClient = new AckPipeReader(input.Reader); + var ackWriterApp = new AckPipeWriter(output); + var ackReaderApp = new AckPipeReader(output); + var ackWriterClient = new AckPipeWriter(input); + var ackReaderClient = new AckPipeReader(input); var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReaderApp); var applicationReader = new ParseAckPipeReader(ackReaderApp, ackWriterClient, ackReaderClient); var transportToApplication = new DuplexPipe(applicationReader, ackWriterClient); diff --git a/src/SignalR/docs/specs/TransportProtocols.md b/src/SignalR/docs/specs/TransportProtocols.md index 13bf6d03f530..4e155b9b8952 100644 --- a/src/SignalR/docs/specs/TransportProtocols.md +++ b/src/SignalR/docs/specs/TransportProtocols.md @@ -209,7 +209,7 @@ When the client has finished with the connection, it can issue a `DELETE` reques ## Ack Protocol The ack protocol primarily consists of writing and reading framing around the data being sent and received. -All sends need to start with a 24 byte frame. The frame consists of 2 64-bit little-endian values, both base-64 encoded (preserving padding) for a total of 12 bytes. The first base-64 value when decoded is the length of the payload being sent (minus the framing) as an int64 value. The second base-64 value when decoded is the ack ID as an int64 of how many bytes have been received from the other side so far. +All sends need to start with a 24 byte frame. The frame consists of 2 64-bit little-endian values (8 bytes), both base-64 encoded (preserving padding) for a total of 2 12 byte base-64 values. The first base-64 value when decoded is the length of the payload being sent (minus the framing) as an int64 value. The second base-64 value when decoded is the ack ID as an int64 of how many bytes have been received from the other side so far. The second part of the protocol is for when the transport ungracefully reconnects and uses the Ack IDs to get any data that might have been missed during the disconnect window. This will be described after showing the framing. From a2b212e416d4a76a8176a1d39c211d5e5072f581 Mon Sep 17 00:00:00 2001 From: Brennan Date: Thu, 11 May 2023 09:22:30 -0700 Subject: [PATCH 13/25] stash --- .../src/Protocol/JsonHubProtocol.cs | 75 +++++++++++++++---- .../SignalR.Common/src/Protocol/AckMessage.cs | 30 ++++++++ .../src/Protocol/CancelInvocationMessage.cs | 9 +++ .../src/Protocol/CompletionMessage.cs | 22 ++++++ .../src/Protocol/HubInvocationMessage.cs | 16 ++++ .../Protocol/HubMethodInvocationMessage.cs | 40 ++++++++++ .../src/Protocol/HubProtocolConstants.cs | 5 ++ .../src/Protocol/StreamItemMessage.cs | 11 +++ .../PublicAPI/net462/PublicAPI.Unshipped.txt | 10 +++ .../server/Core/src/HubConnectionContext.cs | 25 ++++++- .../Core/src/Internal/DefaultHubDispatcher.cs | 4 + .../server/Core/src/Internal/MessageBuffer.cs | 36 +++++++++ 12 files changed, 263 insertions(+), 20 deletions(-) create mode 100644 src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs create mode 100644 src/SignalR/server/Core/src/Internal/MessageBuffer.cs diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 3feaefc13d10..236b5ac91e4b 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -43,6 +43,8 @@ public sealed class JsonHubProtocol : IHubProtocol private static readonly JsonEncodedText ArgumentsPropertyNameBytes = JsonEncodedText.Encode(ArgumentsPropertyName); private const string HeadersPropertyName = "headers"; private static readonly JsonEncodedText HeadersPropertyNameBytes = JsonEncodedText.Encode(HeadersPropertyName); + private const string SequenceIdPropertyName = "sequenceId"; + private static readonly JsonEncodedText SequenceIdPropertyNameBytes = JsonEncodedText.Encode(SequenceIdPropertyName); private const string ProtocolName = "json"; private const int ProtocolVersion = 1; @@ -139,6 +141,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) Dictionary? headers = null; var completed = false; var allowReconnect = false; + string? sequenceId = null; var reader = new Utf8JsonReader(input, isFinalBlock: true, state: default); @@ -325,6 +328,10 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) reader.CheckRead(); headers = ReadHeaders(ref reader); } + else if (reader.ValueTextEquals(SequenceIdPropertyNameBytes.EncodedUtf8Bytes)) + { + sequenceId = reader.ReadAsString(SequenceIdPropertyName); + } else { reader.CheckRead(); @@ -365,7 +372,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) message = argumentBindingException != null ? new InvocationBindingFailureMessage(invocationId, target, argumentBindingException) - : BindInvocationMessage(invocationId, target, arguments, hasArguments, streamIds); + : BindInvocationMessage(invocationId, sequenceId, target, arguments, hasArguments, streamIds); } break; case HubProtocolConstants.StreamInvocationMessageType: @@ -391,7 +398,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) message = argumentBindingException != null ? new InvocationBindingFailureMessage(invocationId, target, argumentBindingException) - : BindStreamInvocationMessage(invocationId, target, arguments, hasArguments, streamIds); + : BindStreamInvocationMessage(invocationId, sequenceId, target, arguments, hasArguments, streamIds); } break; case HubProtocolConstants.StreamItemMessageType: @@ -414,7 +421,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } } - message = BindStreamItemMessage(invocationId, item, hasItem); + message = BindStreamItemMessage(invocationId, sequenceId, item, hasItem); break; case HubProtocolConstants.CompletionMessageType: if (invocationId is null) @@ -443,15 +450,17 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } } - message = BindCompletionMessage(invocationId, error, result, hasResult); + message = BindCompletionMessage(invocationId, sequenceId, error, result, hasResult); break; case HubProtocolConstants.CancelInvocationMessageType: - message = BindCancelInvocationMessage(invocationId); + message = BindCancelInvocationMessage(invocationId, sequenceId); break; case HubProtocolConstants.PingMessageType: return PingMessage.Instance; case HubProtocolConstants.CloseMessageType: return BindCloseMessage(error, allowReconnect); + case HubProtocolConstants.AckMessageType: + return BindAckMessage(sequenceId); case null: throw new InvalidDataException($"Missing required property '{TypePropertyName}'."); default: @@ -544,6 +553,10 @@ private void WriteMessageCore(HubMessage message, IBufferWriter stream) WriteMessageType(writer, HubProtocolConstants.CloseMessageType); WriteCloseMessage(m, writer); break; + case AckMessage m: + WriteMessageType(writer, HubProtocolConstants.AckMessageType); + WriteAckMessage(m, writer); + break; default: throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); } @@ -573,6 +586,7 @@ private static void WriteHeaders(Utf8JsonWriter writer, HubInvocationMessage mes private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); + WriteSequenceId(message, writer); if (!string.IsNullOrEmpty(message.Error)) { writer.WriteString(ErrorPropertyNameBytes, message.Error); @@ -601,11 +615,13 @@ private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter wr private static void WriteCancelInvocationMessage(CancelInvocationMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); + WriteSequenceId(message, writer); } private void WriteStreamItemMessage(StreamItemMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); + WriteSequenceId(message, writer); writer.WritePropertyName(ItemPropertyNameBytes); if (message.Item == null) @@ -621,6 +637,7 @@ private void WriteStreamItemMessage(StreamItemMessage message, Utf8JsonWriter wr private void WriteInvocationMessage(InvocationMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); + WriteSequenceId(message, writer); writer.WriteString(TargetPropertyNameBytes, message.Target); WriteArguments(message.Arguments, writer); @@ -631,6 +648,7 @@ private void WriteInvocationMessage(InvocationMessage message, Utf8JsonWriter wr private void WriteStreamInvocationMessage(StreamInvocationMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); + WriteSequenceId(message, writer); writer.WriteString(TargetPropertyNameBytes, message.Target); WriteArguments(message.Arguments, writer); @@ -651,6 +669,11 @@ private static void WriteCloseMessage(CloseMessage message, Utf8JsonWriter write } } + private static void WriteAckMessage(AckMessage message, Utf8JsonWriter writer) + { + writer.WriteString(SequenceIdPropertyName, message.SequenceId); + } + private void WriteArguments(object?[] arguments, Utf8JsonWriter writer) { writer.WriteStartArray(ArgumentsPropertyNameBytes); @@ -691,22 +714,30 @@ private static void WriteInvocationId(HubInvocationMessage message, Utf8JsonWrit } } + private static void WriteSequenceId(HubInvocationMessage message, Utf8JsonWriter writer) + { + if (!string.IsNullOrEmpty(message.SequenceId)) + { + writer.WriteString(SequenceIdPropertyNameBytes, message.SequenceId); + } + } + private static void WriteMessageType(Utf8JsonWriter writer, int type) { writer.WriteNumber(TypePropertyNameBytes, type); } - private static HubMessage BindCancelInvocationMessage(string? invocationId) + private static HubMessage BindCancelInvocationMessage(string? invocationId, string? sequenceId) { if (string.IsNullOrEmpty(invocationId)) { throw new InvalidDataException($"Missing required property '{InvocationIdPropertyName}'."); } - return new CancelInvocationMessage(invocationId); + return new CancelInvocationMessage(invocationId, sequenceId); } - private static HubMessage BindCompletionMessage(string invocationId, string? error, object? result, bool hasResult) + private static HubMessage BindCompletionMessage(string invocationId, string? sequenceId, string? error, object? result, bool hasResult) { if (string.IsNullOrEmpty(invocationId)) { @@ -720,13 +751,13 @@ private static HubMessage BindCompletionMessage(string invocationId, string? err if (hasResult) { - return new CompletionMessage(invocationId, error, result, hasResult: true); + return new CompletionMessage(invocationId, sequenceId, error, result, hasResult: true); } - return new CompletionMessage(invocationId, error, result: null, hasResult: false); + return new CompletionMessage(invocationId, sequenceId, error, result: null, hasResult: false); } - private static HubMessage BindStreamItemMessage(string invocationId, object? item, bool hasItem) + private static HubMessage BindStreamItemMessage(string invocationId, string? sequenceId, object? item, bool hasItem) { if (string.IsNullOrEmpty(invocationId)) { @@ -738,10 +769,11 @@ private static HubMessage BindStreamItemMessage(string invocationId, object? ite throw new InvalidDataException($"Missing required property '{ItemPropertyName}'."); } - return new StreamItemMessage(invocationId, item); + return new StreamItemMessage(invocationId, sequenceId, item); } - private static HubMessage BindStreamInvocationMessage(string? invocationId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) + private static HubMessage BindStreamInvocationMessage(string? invocationId, string? sequenceId, string target, + object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(invocationId)) { @@ -760,10 +792,11 @@ private static HubMessage BindStreamInvocationMessage(string? invocationId, stri Debug.Assert(arguments != null); - return new StreamInvocationMessage(invocationId, target, arguments, streamIds); + return new StreamInvocationMessage(invocationId, sequenceId, target, arguments, streamIds); } - private static HubMessage BindInvocationMessage(string? invocationId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) + private static HubMessage BindInvocationMessage(string? invocationId, string? sequenceId, string target, + object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(target)) { @@ -777,7 +810,7 @@ private static HubMessage BindInvocationMessage(string? invocationId, string tar Debug.Assert(arguments != null); - return new InvocationMessage(invocationId, target, arguments, streamIds); + return new InvocationMessage(invocationId, sequenceId, target, arguments, streamIds); } private object? BindType(ref Utf8JsonReader reader, ReadOnlySequence input, Type type) @@ -853,6 +886,16 @@ private static CloseMessage BindCloseMessage(string? error, bool allowReconnect) return new CloseMessage(error, allowReconnect); } + private static AckMessage BindAckMessage(string? sequenceId) + { + if (string.IsNullOrEmpty(sequenceId)) + { + throw new InvalidDataException("Missing 'sequenceId' in Ack message."); + } + + return new AckMessage(sequenceId); + } + private static HubMessage ApplyHeaders(HubMessage message, Dictionary? headers) { if (headers != null && message is HubInvocationMessage invocationMessage) diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs new file mode 100644 index 000000000000..30e342c07dd7 --- /dev/null +++ b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +/// +/// +/// +public sealed class AckMessage : HubMessage +{ + /// + /// + /// + /// + public AckMessage(string sequenceId) + { + SequenceId = sequenceId; + } + + /// + /// + /// + public string SequenceId { get; } +} diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs index 5cc9dba052a1..31e80b8fd1b9 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs @@ -17,4 +17,13 @@ public class CancelInvocationMessage : HubInvocationMessage public CancelInvocationMessage(string invocationId) : base(invocationId) { } + + /// + /// + /// + /// + /// + public CancelInvocationMessage(string invocationId, string? sequenceId) : base(invocationId, sequenceId) + { + } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs index 440e431c6bb8..9620247787e9 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs @@ -46,6 +46,28 @@ public CompletionMessage(string invocationId, string? error, object? result, boo HasResult = hasResult; } + /// + /// + /// + /// + /// + /// + /// + /// + /// + public CompletionMessage(string invocationId, string? sequenceId, string? error, object? result, bool hasResult) + : base(invocationId, sequenceId) + { + if (error is not null && hasResult) + { + throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both"); + } + + Error = error; + Result = result; + HasResult = hasResult; + } + /// public override string ToString() { diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs index fd97e969563f..0ebaaf7e8fdc 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs @@ -20,6 +20,11 @@ public abstract class HubInvocationMessage : HubMessage /// public string? InvocationId { get; } + /// + /// TODO; monotonically increasing ID that identifies this message + /// + public string? SequenceId { get; } + /// /// Initializes a new instance of the class. /// @@ -28,4 +33,15 @@ protected HubInvocationMessage(string? invocationId) { InvocationId = invocationId; } + + /// + /// TODO + /// + /// + /// + protected HubInvocationMessage(string? invocationId, string? sequenceId) + : this(invocationId) + { + SequenceId = sequenceId; + } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs index cc6be99a25ba..841a6a220612 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs @@ -27,6 +27,28 @@ public abstract class HubMethodInvocationMessage : HubInvocationMessage /// public string[]? StreamIds { get; } + /// + /// + /// + /// + /// + /// + /// + /// + /// + protected HubMethodInvocationMessage(string? invocationId, string? sequenceId, string target, object?[] arguments, string[]? streamIds) + : base(invocationId, sequenceId) + { + if (string.IsNullOrEmpty(target)) + { + throw new ArgumentNullException(nameof(target)); + } + + Target = target; + Arguments = arguments; + StreamIds = streamIds; + } + /// /// Initializes a new instance of the class. /// @@ -94,6 +116,11 @@ public InvocationMessage(string? invocationId, string target, object?[] argument { } + public InvocationMessage(string? invocationId, string? sequenceId, string target, object?[] arguments, string[]? streamIds) + : base(invocationId, sequenceId, target, arguments, streamIds) + { + } + /// public override string ToString() { @@ -149,6 +176,19 @@ public StreamInvocationMessage(string invocationId, string target, object?[] arg { } + /// + /// + /// + /// + /// + /// + /// + /// + public StreamInvocationMessage(string invocationId, string? sequenceId, string target, object?[] arguments, string[]? streamIds) + : base(invocationId, sequenceId, target, arguments, streamIds) + { + } + /// public override string ToString() { diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs index eb1e3914ac17..538e07ce0e03 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs @@ -42,4 +42,9 @@ public static class HubProtocolConstants /// Represents the close message type. /// public const int CloseMessageType = 7; + + /// + /// + /// + public const int AckMessageType = 8; } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs index dafba133a0f5..a4c410abef57 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs @@ -23,6 +23,17 @@ public StreamItemMessage(string invocationId, object? item) : base(invocationId) Item = item; } + /// + /// + /// + /// + /// + /// + public StreamItemMessage(string invocationId, string? sequenceId, object? item) : base(invocationId, sequenceId) + { + Item = item; + } + /// public override string ToString() { diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt index 7dc5c58110bf..0afbd3fec7cc 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt @@ -1 +1,11 @@ #nullable enable +Microsoft.AspNetCore.SignalR.Protocol.AckMessage +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(string! sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> string! +Microsoft.AspNetCore.SignalR.Protocol.CancelInvocationMessage.CancelInvocationMessage(string! invocationId, string! sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage.CompletionMessage(string! invocationId, string! sequenceId, string? error, object? result, bool hasResult) -> void +Microsoft.AspNetCore.SignalR.Protocol.HubInvocationMessage.HubInvocationMessage(string? invocationId, string! sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.HubInvocationMessage.SequenceId.get -> string? +Microsoft.AspNetCore.SignalR.Protocol.HubMethodInvocationMessage.HubMethodInvocationMessage(string? invocationId, string! sequenceId, string! target, object?[]! arguments, string![]? streamIds) -> void +Microsoft.AspNetCore.SignalR.Protocol.StreamInvocationMessage.StreamInvocationMessage(string! invocationId, string! sequenceId, string! target, object?[]! arguments, string![]? streamIds) -> void +Microsoft.AspNetCore.SignalR.Protocol.StreamItemMessage.StreamItemMessage(string! invocationId, string! sequenceId, object? item) -> void diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index a2a9f24429ef..670300770aef 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -254,11 +254,22 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c { try { - // We know that we are only writing this message to one receiver, so we can - // write it without caching. - Protocol.WriteMessage(message, _connectionContext.Transport.Output); + // TODO + var isAck = true; + if (isAck) + { + var m = new SerializedHubMessage(message); + var bytes = m.GetSerializedMessage(Protocol); + return _connectionContext.Transport.Output.WriteAsync(bytes, cancellationToken); + } + else + { + // We know that we are only writing this message to one receiver, so we can + // write it without caching. + Protocol.WriteMessage(message, _connectionContext.Transport.Output); - return _connectionContext.Transport.Output.FlushAsync(cancellationToken); + return _connectionContext.Transport.Output.FlushAsync(cancellationToken); + } } catch (Exception ex) { @@ -731,4 +742,10 @@ internal void Cleanup() // Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist _streamTracker?.CompleteAll(new OperationCanceledException("The underlying connection was closed.")); } + + internal void Ack(AckMessage ackMessage) + { + // Remove from ring buffer + // ackMessage.SequenceId + } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 3458f6760a9f..ec4dd32c7f15 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -186,6 +186,10 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe } break; + case AckMessage ackMessage: + connection.Ack(ackMessage); + break; + // Other kind of message we weren't expecting default: Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName!); diff --git a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs new file mode 100644 index 000000000000..d795f62c942a --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO.Pipelines; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal sealed class MessageBuffer +{ + private readonly SerializedHubMessage[] _buffer; + private int _index; + + // TODO: pass in limits + public MessageBuffer() + { + _buffer = new SerializedHubMessage[10]; + } + + public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, + CancellationToken cancellationToken) + { + // No lock because this is always called in a single async loop? + // And other methods don't affect the checks here? + + // TODO: Backpressure + + if (_buffer[_index] is not null) + { + // ... + } + _buffer[_index] = hubMessage; + _index = _index + 1 % _buffer.Length; + await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); + } +} From b6a91a78545d5b6da34e41670238c2af9a2c610e Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Fri, 12 May 2023 09:18:01 -0700 Subject: [PATCH 14/25] stash --- .../src/Protocol/JsonHubProtocol.cs | 26 +++++----- .../common/Shared/SystemTextJsonExtensions.cs | 17 +++++++ .../SignalR.Common/src/Protocol/AckMessage.cs | 4 +- .../src/Protocol/CancelInvocationMessage.cs | 2 +- .../src/Protocol/CompletionMessage.cs | 23 ++++++++- .../src/Protocol/HubInvocationMessage.cs | 4 +- .../Protocol/HubMethodInvocationMessage.cs | 6 +-- .../src/Protocol/StreamItemMessage.cs | 2 +- .../server/Core/src/HubConnectionContext.cs | 46 ++++++++++++++--- .../server/Core/src/HubConnectionHandler.cs | 3 ++ src/SignalR/server/Core/src/HubOptions.cs | 2 + .../Core/src/Internal/DefaultHubDispatcher.cs | 49 ++++++++++++------- .../server/Core/src/Internal/MessageBuffer.cs | 41 +++++++++++++--- 13 files changed, 170 insertions(+), 55 deletions(-) diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 236b5ac91e4b..7dd2317f3130 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -141,7 +141,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) Dictionary? headers = null; var completed = false; var allowReconnect = false; - string? sequenceId = null; + long? sequenceId = null; var reader = new Utf8JsonReader(input, isFinalBlock: true, state: default); @@ -330,7 +330,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else if (reader.ValueTextEquals(SequenceIdPropertyNameBytes.EncodedUtf8Bytes)) { - sequenceId = reader.ReadAsString(SequenceIdPropertyName); + sequenceId = reader.ReadAsInt64(SequenceIdPropertyName); } else { @@ -671,7 +671,7 @@ private static void WriteCloseMessage(CloseMessage message, Utf8JsonWriter write private static void WriteAckMessage(AckMessage message, Utf8JsonWriter writer) { - writer.WriteString(SequenceIdPropertyName, message.SequenceId); + writer.WriteNumber(SequenceIdPropertyName, message.SequenceId); } private void WriteArguments(object?[] arguments, Utf8JsonWriter writer) @@ -716,9 +716,9 @@ private static void WriteInvocationId(HubInvocationMessage message, Utf8JsonWrit private static void WriteSequenceId(HubInvocationMessage message, Utf8JsonWriter writer) { - if (!string.IsNullOrEmpty(message.SequenceId)) + if (message.SequenceId is not null) { - writer.WriteString(SequenceIdPropertyNameBytes, message.SequenceId); + writer.WriteNumber(SequenceIdPropertyNameBytes, message.SequenceId.Value); } } @@ -727,7 +727,7 @@ private static void WriteMessageType(Utf8JsonWriter writer, int type) writer.WriteNumber(TypePropertyNameBytes, type); } - private static HubMessage BindCancelInvocationMessage(string? invocationId, string? sequenceId) + private static HubMessage BindCancelInvocationMessage(string? invocationId, long? sequenceId) { if (string.IsNullOrEmpty(invocationId)) { @@ -737,7 +737,7 @@ private static HubMessage BindCancelInvocationMessage(string? invocationId, stri return new CancelInvocationMessage(invocationId, sequenceId); } - private static HubMessage BindCompletionMessage(string invocationId, string? sequenceId, string? error, object? result, bool hasResult) + private static HubMessage BindCompletionMessage(string invocationId, long? sequenceId, string? error, object? result, bool hasResult) { if (string.IsNullOrEmpty(invocationId)) { @@ -757,7 +757,7 @@ private static HubMessage BindCompletionMessage(string invocationId, string? seq return new CompletionMessage(invocationId, sequenceId, error, result: null, hasResult: false); } - private static HubMessage BindStreamItemMessage(string invocationId, string? sequenceId, object? item, bool hasItem) + private static HubMessage BindStreamItemMessage(string invocationId, long? sequenceId, object? item, bool hasItem) { if (string.IsNullOrEmpty(invocationId)) { @@ -772,7 +772,7 @@ private static HubMessage BindStreamItemMessage(string invocationId, string? seq return new StreamItemMessage(invocationId, sequenceId, item); } - private static HubMessage BindStreamInvocationMessage(string? invocationId, string? sequenceId, string target, + private static HubMessage BindStreamInvocationMessage(string? invocationId, long? sequenceId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(invocationId)) @@ -795,7 +795,7 @@ private static HubMessage BindStreamInvocationMessage(string? invocationId, stri return new StreamInvocationMessage(invocationId, sequenceId, target, arguments, streamIds); } - private static HubMessage BindInvocationMessage(string? invocationId, string? sequenceId, string target, + private static HubMessage BindInvocationMessage(string? invocationId, long? sequenceId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(target)) @@ -886,14 +886,14 @@ private static CloseMessage BindCloseMessage(string? error, bool allowReconnect) return new CloseMessage(error, allowReconnect); } - private static AckMessage BindAckMessage(string? sequenceId) + private static AckMessage BindAckMessage(long? sequenceId) { - if (string.IsNullOrEmpty(sequenceId)) + if (sequenceId is null) { throw new InvalidDataException("Missing 'sequenceId' in Ack message."); } - return new AckMessage(sequenceId); + return new AckMessage(sequenceId.Value); } private static HubMessage ApplyHeaders(HubMessage message, Dictionary? headers) diff --git a/src/SignalR/common/Shared/SystemTextJsonExtensions.cs b/src/SignalR/common/Shared/SystemTextJsonExtensions.cs index c28ba9e16398..30f1c9adc6ab 100644 --- a/src/SignalR/common/Shared/SystemTextJsonExtensions.cs +++ b/src/SignalR/common/Shared/SystemTextJsonExtensions.cs @@ -97,4 +97,21 @@ public static string ReadAsString(this ref Utf8JsonReader reader, string propert return reader.GetInt32(); } + + public static long? ReadAsInt64(this ref Utf8JsonReader reader, string propertyName) + { + reader.Read(); + + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType != JsonTokenType.Number) + { + throw new InvalidDataException($"Expected '{propertyName}' to be of type {JsonTokenType.Number}."); + } + + return reader.GetInt64(); + } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs index 30e342c07dd7..916e8e0a2aba 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs @@ -18,7 +18,7 @@ public sealed class AckMessage : HubMessage /// /// /// - public AckMessage(string sequenceId) + public AckMessage(long sequenceId) { SequenceId = sequenceId; } @@ -26,5 +26,5 @@ public AckMessage(string sequenceId) /// /// /// - public string SequenceId { get; } + public long SequenceId { get; } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs index 31e80b8fd1b9..6894b19be801 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs @@ -23,7 +23,7 @@ public CancelInvocationMessage(string invocationId) : base(invocationId) /// /// /// - public CancelInvocationMessage(string invocationId, string? sequenceId) : base(invocationId, sequenceId) + public CancelInvocationMessage(string invocationId, long? sequenceId) : base(invocationId, sequenceId) { } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs index 9620247787e9..1ea5ce5a1fcd 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs @@ -55,7 +55,7 @@ public CompletionMessage(string invocationId, string? error, object? result, boo /// /// /// - public CompletionMessage(string invocationId, string? sequenceId, string? error, object? result, bool hasResult) + public CompletionMessage(string invocationId, long? sequenceId, string? error, object? result, bool hasResult) : base(invocationId, sequenceId) { if (error is not null && hasResult) @@ -104,4 +104,25 @@ public static CompletionMessage WithResult(string invocationId, object? payload) /// The constructed . public static CompletionMessage Empty(string invocationId) => new CompletionMessage(invocationId, error: null, result: null, hasResult: false); + + public static CompletionMessage WithError(string invocationId, long? sequenceId, string? error) + => new CompletionMessage(invocationId, sequenceId, error, result: null, hasResult: false); + + /// + /// Constructs a with a result. + /// + /// The ID of the invocation that is being completed. + /// The result from the invocation. + /// The constructed . + public static CompletionMessage WithResult(string invocationId, long? sequenceId, object? payload) + => new CompletionMessage(invocationId, sequenceId, error: null, result: payload, hasResult: true); + + /// + /// Constructs a without an error or result. + /// This means the invocation was successful but there is no return value. + /// + /// The ID of the invocation that is being completed. + /// The constructed . + public static CompletionMessage Empty(string invocationId, long? sequenceId) + => new CompletionMessage(invocationId, sequenceId, error: null, result: null, hasResult: false); } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs index 0ebaaf7e8fdc..43e6278cfd9a 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs @@ -23,7 +23,7 @@ public abstract class HubInvocationMessage : HubMessage /// /// TODO; monotonically increasing ID that identifies this message /// - public string? SequenceId { get; } + public long? SequenceId { get; } /// /// Initializes a new instance of the class. @@ -39,7 +39,7 @@ protected HubInvocationMessage(string? invocationId) /// /// /// - protected HubInvocationMessage(string? invocationId, string? sequenceId) + protected HubInvocationMessage(string? invocationId, long? sequenceId) : this(invocationId) { SequenceId = sequenceId; diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs index 841a6a220612..5c1d7740bfc4 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs @@ -36,7 +36,7 @@ public abstract class HubMethodInvocationMessage : HubInvocationMessage /// /// /// - protected HubMethodInvocationMessage(string? invocationId, string? sequenceId, string target, object?[] arguments, string[]? streamIds) + protected HubMethodInvocationMessage(string? invocationId, long? sequenceId, string target, object?[] arguments, string[]? streamIds) : base(invocationId, sequenceId) { if (string.IsNullOrEmpty(target)) @@ -116,7 +116,7 @@ public InvocationMessage(string? invocationId, string target, object?[] argument { } - public InvocationMessage(string? invocationId, string? sequenceId, string target, object?[] arguments, string[]? streamIds) + public InvocationMessage(string? invocationId, long? sequenceId, string target, object?[] arguments, string[]? streamIds) : base(invocationId, sequenceId, target, arguments, streamIds) { } @@ -184,7 +184,7 @@ public StreamInvocationMessage(string invocationId, string target, object?[] arg /// /// /// - public StreamInvocationMessage(string invocationId, string? sequenceId, string target, object?[] arguments, string[]? streamIds) + public StreamInvocationMessage(string invocationId, long? sequenceId, string target, object?[] arguments, string[]? streamIds) : base(invocationId, sequenceId, target, arguments, streamIds) { } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs index a4c410abef57..93dd0e1f89ec 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs @@ -29,7 +29,7 @@ public StreamItemMessage(string invocationId, object? item) : base(invocationId) /// /// /// - public StreamItemMessage(string invocationId, string? sequenceId, object? item) : base(invocationId, sequenceId) + public StreamItemMessage(string invocationId, long? sequenceId, object? item) : base(invocationId, sequenceId) { Item = item; } diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 670300770aef..ad315c86af9f 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -35,6 +35,7 @@ public partial class HubConnectionContext private readonly TimeProvider _timeProvider; private readonly CancellationTokenRegistration _closedRegistration; private readonly CancellationTokenRegistration? _closedRequestedRegistration; + private readonly MessageBuffer _messageBuffer = new(); private StreamTracker? _streamTracker; private long _lastSendTick; @@ -49,6 +50,10 @@ public partial class HubConnectionContext private long _receivedMessageTick; private ClaimsPrincipal? _user; + internal bool UseAcks; + private long _sequenceId; + private long _latestReceivedSequenceId = long.MinValue; + /// /// Initializes a new instance of the class. /// @@ -258,9 +263,7 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c var isAck = true; if (isAck) { - var m = new SerializedHubMessage(message); - var bytes = m.GetSerializedMessage(Protocol); - return _connectionContext.Transport.Output.WriteAsync(bytes, cancellationToken); + return _messageBuffer.WriteAsync(_connectionContext.Transport.Output, new SerializedHubMessage(message), Protocol, cancellationToken); } else { @@ -286,10 +289,19 @@ private ValueTask WriteCore(SerializedHubMessage message, Cancellat { try { - // Grab a preserialized buffer for this protocol. - var buffer = message.GetSerializedMessage(Protocol); + // TODO + var isAck = true; + if (isAck) + { + return _messageBuffer.WriteAsync(_connectionContext.Transport.Output, message, Protocol, cancellationToken); + } + else + { + // Grab a potentially pre-serialized buffer for this protocol. + var buffer = message.GetSerializedMessage(Protocol); - return _connectionContext.Transport.Output.WriteAsync(buffer, cancellationToken); + return _connectionContext.Transport.Output.WriteAsync(buffer, cancellationToken); + } } catch (Exception ex) { @@ -746,6 +758,26 @@ internal void Cleanup() internal void Ack(AckMessage ackMessage) { // Remove from ring buffer - // ackMessage.SequenceId + _messageBuffer.Ack(ackMessage); + } + + private long? GetSequenceId() + { + if (UseAcks) + { + return Interlocked.Increment(ref _sequenceId); + } + return null; + } + + internal bool ShouldProcessMessage(HubInvocationMessage message) + { + if (message.SequenceId <= _latestReceivedSequenceId) + { + // Ignore, this is a duplicate message + return false; + } + _latestReceivedSequenceId = message.SequenceId!.Value; + return true; } } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index ab3d0f5bbd7b..f0aa90491a34 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -70,6 +70,7 @@ IServiceScopeFactory serviceScopeFactory _enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors; _maxParallelInvokes = _hubOptions.MaximumParallelInvocationsPerClient; disableImplicitFromServiceParameters = _hubOptions.DisableImplicitFromServicesParameters; + var _ = _hubOptions.UseAcks; if (_hubOptions.HubFilters != null) { @@ -82,6 +83,7 @@ IServiceScopeFactory serviceScopeFactory _enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors; _maxParallelInvokes = _globalHubOptions.MaximumParallelInvocationsPerClient; disableImplicitFromServiceParameters = _globalHubOptions.DisableImplicitFromServicesParameters; + var _ = _globalHubOptions.UseAcks; if (_globalHubOptions.HubFilters != null) { @@ -94,6 +96,7 @@ IServiceScopeFactory serviceScopeFactory new HubContext(lifetimeManager), _enableDetailedErrors, disableImplicitFromServiceParameters, + useAcks: true, new Logger>(loggerFactory), hubFilters, lifetimeManager); diff --git a/src/SignalR/server/Core/src/HubOptions.cs b/src/SignalR/server/Core/src/HubOptions.cs index 3a4e0883f884..9e2a2fbe979b 100644 --- a/src/SignalR/server/Core/src/HubOptions.cs +++ b/src/SignalR/server/Core/src/HubOptions.cs @@ -79,4 +79,6 @@ public int MaximumParallelInvocationsPerClient /// False by default. Hub method arguments will be resolved from a DI container if possible. /// public bool DisableImplicitFromServicesParameters { get; set; } + + public bool UseAcks { get; set; } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index ec4dd32c7f15..46ea44b5ec2b 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -28,15 +28,17 @@ internal sealed partial class DefaultHubDispatcher : HubDispatcher w private readonly Func? _onConnectedMiddleware; private readonly Func? _onDisconnectedMiddleware; private readonly HubLifetimeManager _hubLifetimeManager; + private readonly bool _useAcks; public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext hubContext, bool enableDetailedErrors, - bool disableImplicitFromServiceParameters, ILogger> logger, List? hubFilters, HubLifetimeManager lifetimeManager) + bool disableImplicitFromServiceParameters, bool useAcks, ILogger> logger, List? hubFilters, HubLifetimeManager lifetimeManager) { _serviceScopeFactory = serviceScopeFactory; _hubContext = hubContext; _enableDetailedErrors = enableDetailedErrors; _logger = logger; _hubLifetimeManager = lifetimeManager; + _useAcks = useAcks; DiscoverHubMethods(disableImplicitFromServiceParameters); var count = hubFilters?.Count ?? 0; @@ -72,6 +74,9 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { + // TODO: figure out when this should be true + connection.UseAcks = true; + await using var scope = _serviceScopeFactory.CreateAsyncScope(); var hubActivator = scope.ServiceProvider.GetRequiredService>(); @@ -130,6 +135,14 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe // With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running. + if (_useAcks && hubMessage is HubInvocationMessage invocation) + { + if (!connection.ShouldProcessMessage(invocation)) + { + return Task.CompletedTask; + } + } + switch (hubMessage) { case InvocationBindingFailureMessage bindingFailureMessage: @@ -205,7 +218,7 @@ private Task ProcessInvocationBindingFailure(HubConnectionContext connection, In var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.", bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); - return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage); + return SendInvocationError(bindingFailureMessage.InvocationId, connection.GetSequenceId(), connection, errorMessage); } private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage) @@ -214,7 +227,7 @@ private Task ProcessStreamBindingFailure(HubConnectionContext connection, Stream "Failed to bind Stream message.", bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); - var message = CompletionMessage.WithError(bindingFailureMessage.Id, errorString); + var message = CompletionMessage.WithError(bindingFailureMessage.Id, connection.GetSequenceId(), errorString); Log.ClosingStreamWithBindingError(_logger, message); // ignore failure, it means the client already completed the stream or the stream never existed on the server @@ -247,7 +260,7 @@ private Task ProcessInvocation(HubConnectionContext connection, { // Send an error to the client. Then let the normal completion process occur return connection.WriteAsync(CompletionMessage.WithError( - hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask(); + hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask(); } else { @@ -290,7 +303,7 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionCon if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor, hubMethodInvocationMessage.Arguments, hub)) { Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized"); return true; } @@ -308,7 +321,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}."); Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); return true; } @@ -361,7 +374,7 @@ static async Task ExecuteInvocation(DefaultHubDispatcher dispatcher, catch (Exception ex) { Log.FailedInvokingHubMethod(logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, enableDetailedErrors)); return; } @@ -378,8 +391,8 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, // No InvocationId - Send Async, no response expected if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { - // Invoke Async, one reponse expected - await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); + // Invoke Async, one response expected + await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), result)); } } @@ -401,13 +414,13 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, catch (TargetInvocationException ex) { Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex.InnerException ?? ex, _enableDetailedErrors)); } catch (Exception ex) { Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); } } @@ -433,7 +446,7 @@ private static ValueTask CleanupInvocation(HubConnectionContext connection, HubM { foreach (var stream in hubMessage.StreamIds) { - connection.StreamTracker.TryComplete(CompletionMessage.Empty(stream)); + connection.StreamTracker.TryComplete(CompletionMessage.Empty(stream, connection.GetSequenceId())); } } @@ -482,7 +495,7 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect await using var enumerator = descriptor.FromReturnedStream(result, streamCts.Token); Log.StreamingResult(_logger, invocationId, descriptor.MethodExecutor); - var streamItemMessage = new StreamItemMessage(invocationId, null); + var streamItemMessage = new StreamItemMessage(invocationId, connection.GetSequenceId(), null); while (await enumerator.MoveNextAsync()) { @@ -513,7 +526,7 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect streamCts.Dispose(); connection.ActiveRequestCancellationSources.TryRemove(invocationId, out _); - await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); + await connection.WriteAsync(CompletionMessage.WithError(invocationId, connection.GetSequenceId(), error)); } } @@ -559,7 +572,7 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect } } - private static async Task SendInvocationError(string? invocationId, + private static async Task SendInvocationError(string? invocationId, long? sequenceId, HubConnectionContext connection, string errorMessage) { if (string.IsNullOrEmpty(invocationId)) @@ -567,7 +580,7 @@ private static async Task SendInvocationError(string? invocationId, return; } - await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage)); + await connection.WriteAsync(CompletionMessage.WithError(invocationId, sequenceId, errorMessage)); } private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true) @@ -611,7 +624,7 @@ private async Task ValidateInvocationMode(HubMethodDescriptor hubMethodDes if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage); - await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, + await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation.")); } @@ -621,7 +634,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse) { Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage); - await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, + await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, connection.GetSequenceId(), $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation.")); return false; diff --git a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs index d795f62c942a..4ff10efef53d 100644 --- a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs +++ b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs @@ -8,16 +8,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class MessageBuffer { - private readonly SerializedHubMessage[] _buffer; + private readonly (SerializedHubMessage? Message, long? SequenceId)[] _buffer; private int _index; // TODO: pass in limits public MessageBuffer() { - _buffer = new SerializedHubMessage[10]; + _buffer = new (SerializedHubMessage? Message, long? SequenceId)[10]; } - public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, + public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, CancellationToken cancellationToken) { // No lock because this is always called in a single async loop? @@ -25,12 +25,39 @@ public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hu // TODO: Backpressure - if (_buffer[_index] is not null) + if (_buffer[_index].Message is not null) { // ... } - _buffer[_index] = hubMessage; - _index = _index + 1 % _buffer.Length; - await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); + + long? sequenceId; + if (hubMessage.Message is HubInvocationMessage invocationMessage) + { + sequenceId = invocationMessage.SequenceId; + } + else + { + // Non-ackable message, don't add to buffer + return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); + } + + _buffer[_index] = (hubMessage, sequenceId); + _index = (_index + 1) % _buffer.Length; + return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); + } + + public void Ack(AckMessage ackMessage) + { + var index = _index; + for (var i = 0; i < _buffer.Length; i++) + { + var currentIndex = (index + i) % _buffer.Length; + if (_buffer[currentIndex].SequenceId is long id && id <= ackMessage.SequenceId) + { + _buffer[currentIndex] = (null, null); + } + } + + // Release backpressure? } } From 632c7648fa50e898275637476f5a52a40775bbc6 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Fri, 12 May 2023 17:35:31 -0700 Subject: [PATCH 15/25] stash --- .../src/Protocol/JsonHubProtocol.cs | 32 +++++++------- .../src/Protocol/CancelInvocationMessage.cs | 9 ---- .../src/Protocol/CompletionMessage.cs | 43 ------------------- .../src/Protocol/HubInvocationMessage.cs | 16 ------- .../Protocol/HubMethodInvocationMessage.cs | 40 ----------------- .../src/Protocol/StreamItemMessage.cs | 11 ----- .../server/Core/src/HubConnectionContext.cs | 25 ++++++----- .../Core/src/Internal/DefaultHubDispatcher.cs | 33 +++++++------- .../server/Core/src/Internal/MessageBuffer.cs | 15 ++++--- 9 files changed, 55 insertions(+), 169 deletions(-) diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 7dd2317f3130..be03c0d8fd45 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -372,7 +372,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) message = argumentBindingException != null ? new InvocationBindingFailureMessage(invocationId, target, argumentBindingException) - : BindInvocationMessage(invocationId, sequenceId, target, arguments, hasArguments, streamIds); + : BindInvocationMessage(invocationId, target, arguments, hasArguments, streamIds); } break; case HubProtocolConstants.StreamInvocationMessageType: @@ -398,7 +398,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) message = argumentBindingException != null ? new InvocationBindingFailureMessage(invocationId, target, argumentBindingException) - : BindStreamInvocationMessage(invocationId, sequenceId, target, arguments, hasArguments, streamIds); + : BindStreamInvocationMessage(invocationId, target, arguments, hasArguments, streamIds); } break; case HubProtocolConstants.StreamItemMessageType: @@ -421,7 +421,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } } - message = BindStreamItemMessage(invocationId, sequenceId, item, hasItem); + message = BindStreamItemMessage(invocationId, item, hasItem); break; case HubProtocolConstants.CompletionMessageType: if (invocationId is null) @@ -450,10 +450,10 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } } - message = BindCompletionMessage(invocationId, sequenceId, error, result, hasResult); + message = BindCompletionMessage(invocationId, error, result, hasResult); break; case HubProtocolConstants.CancelInvocationMessageType: - message = BindCancelInvocationMessage(invocationId, sequenceId); + message = BindCancelInvocationMessage(invocationId); break; case HubProtocolConstants.PingMessageType: return PingMessage.Instance; @@ -727,17 +727,17 @@ private static void WriteMessageType(Utf8JsonWriter writer, int type) writer.WriteNumber(TypePropertyNameBytes, type); } - private static HubMessage BindCancelInvocationMessage(string? invocationId, long? sequenceId) + private static HubMessage BindCancelInvocationMessage(string? invocationId) { if (string.IsNullOrEmpty(invocationId)) { throw new InvalidDataException($"Missing required property '{InvocationIdPropertyName}'."); } - return new CancelInvocationMessage(invocationId, sequenceId); + return new CancelInvocationMessage(invocationId); } - private static HubMessage BindCompletionMessage(string invocationId, long? sequenceId, string? error, object? result, bool hasResult) + private static HubMessage BindCompletionMessage(string invocationId, string? error, object? result, bool hasResult) { if (string.IsNullOrEmpty(invocationId)) { @@ -751,13 +751,13 @@ private static HubMessage BindCompletionMessage(string invocationId, long? seque if (hasResult) { - return new CompletionMessage(invocationId, sequenceId, error, result, hasResult: true); + return new CompletionMessage(invocationId, error, result, hasResult: true); } - return new CompletionMessage(invocationId, sequenceId, error, result: null, hasResult: false); + return new CompletionMessage(invocationId, error, result: null, hasResult: false); } - private static HubMessage BindStreamItemMessage(string invocationId, long? sequenceId, object? item, bool hasItem) + private static HubMessage BindStreamItemMessage(string invocationId, object? item, bool hasItem) { if (string.IsNullOrEmpty(invocationId)) { @@ -769,10 +769,10 @@ private static HubMessage BindStreamItemMessage(string invocationId, long? seque throw new InvalidDataException($"Missing required property '{ItemPropertyName}'."); } - return new StreamItemMessage(invocationId, sequenceId, item); + return new StreamItemMessage(invocationId, item); } - private static HubMessage BindStreamInvocationMessage(string? invocationId, long? sequenceId, string target, + private static HubMessage BindStreamInvocationMessage(string? invocationId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(invocationId)) @@ -792,10 +792,10 @@ private static HubMessage BindStreamInvocationMessage(string? invocationId, long Debug.Assert(arguments != null); - return new StreamInvocationMessage(invocationId, sequenceId, target, arguments, streamIds); + return new StreamInvocationMessage(invocationId, target, arguments, streamIds); } - private static HubMessage BindInvocationMessage(string? invocationId, long? sequenceId, string target, + private static HubMessage BindInvocationMessage(string? invocationId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(target)) @@ -810,7 +810,7 @@ private static HubMessage BindInvocationMessage(string? invocationId, long? sequ Debug.Assert(arguments != null); - return new InvocationMessage(invocationId, sequenceId, target, arguments, streamIds); + return new InvocationMessage(invocationId, target, arguments, streamIds); } private object? BindType(ref Utf8JsonReader reader, ReadOnlySequence input, Type type) diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs index 6894b19be801..5cc9dba052a1 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CancelInvocationMessage.cs @@ -17,13 +17,4 @@ public class CancelInvocationMessage : HubInvocationMessage public CancelInvocationMessage(string invocationId) : base(invocationId) { } - - /// - /// - /// - /// - /// - public CancelInvocationMessage(string invocationId, long? sequenceId) : base(invocationId, sequenceId) - { - } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs index 1ea5ce5a1fcd..440e431c6bb8 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/CompletionMessage.cs @@ -46,28 +46,6 @@ public CompletionMessage(string invocationId, string? error, object? result, boo HasResult = hasResult; } - /// - /// - /// - /// - /// - /// - /// - /// - /// - public CompletionMessage(string invocationId, long? sequenceId, string? error, object? result, bool hasResult) - : base(invocationId, sequenceId) - { - if (error is not null && hasResult) - { - throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both"); - } - - Error = error; - Result = result; - HasResult = hasResult; - } - /// public override string ToString() { @@ -104,25 +82,4 @@ public static CompletionMessage WithResult(string invocationId, object? payload) /// The constructed . public static CompletionMessage Empty(string invocationId) => new CompletionMessage(invocationId, error: null, result: null, hasResult: false); - - public static CompletionMessage WithError(string invocationId, long? sequenceId, string? error) - => new CompletionMessage(invocationId, sequenceId, error, result: null, hasResult: false); - - /// - /// Constructs a with a result. - /// - /// The ID of the invocation that is being completed. - /// The result from the invocation. - /// The constructed . - public static CompletionMessage WithResult(string invocationId, long? sequenceId, object? payload) - => new CompletionMessage(invocationId, sequenceId, error: null, result: payload, hasResult: true); - - /// - /// Constructs a without an error or result. - /// This means the invocation was successful but there is no return value. - /// - /// The ID of the invocation that is being completed. - /// The constructed . - public static CompletionMessage Empty(string invocationId, long? sequenceId) - => new CompletionMessage(invocationId, sequenceId, error: null, result: null, hasResult: false); } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs index 43e6278cfd9a..fd97e969563f 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubInvocationMessage.cs @@ -20,11 +20,6 @@ public abstract class HubInvocationMessage : HubMessage /// public string? InvocationId { get; } - /// - /// TODO; monotonically increasing ID that identifies this message - /// - public long? SequenceId { get; } - /// /// Initializes a new instance of the class. /// @@ -33,15 +28,4 @@ protected HubInvocationMessage(string? invocationId) { InvocationId = invocationId; } - - /// - /// TODO - /// - /// - /// - protected HubInvocationMessage(string? invocationId, long? sequenceId) - : this(invocationId) - { - SequenceId = sequenceId; - } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs index 5c1d7740bfc4..cc6be99a25ba 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubMethodInvocationMessage.cs @@ -27,28 +27,6 @@ public abstract class HubMethodInvocationMessage : HubInvocationMessage /// public string[]? StreamIds { get; } - /// - /// - /// - /// - /// - /// - /// - /// - /// - protected HubMethodInvocationMessage(string? invocationId, long? sequenceId, string target, object?[] arguments, string[]? streamIds) - : base(invocationId, sequenceId) - { - if (string.IsNullOrEmpty(target)) - { - throw new ArgumentNullException(nameof(target)); - } - - Target = target; - Arguments = arguments; - StreamIds = streamIds; - } - /// /// Initializes a new instance of the class. /// @@ -116,11 +94,6 @@ public InvocationMessage(string? invocationId, string target, object?[] argument { } - public InvocationMessage(string? invocationId, long? sequenceId, string target, object?[] arguments, string[]? streamIds) - : base(invocationId, sequenceId, target, arguments, streamIds) - { - } - /// public override string ToString() { @@ -176,19 +149,6 @@ public StreamInvocationMessage(string invocationId, string target, object?[] arg { } - /// - /// - /// - /// - /// - /// - /// - /// - public StreamInvocationMessage(string invocationId, long? sequenceId, string target, object?[] arguments, string[]? streamIds) - : base(invocationId, sequenceId, target, arguments, streamIds) - { - } - /// public override string ToString() { diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs index 93dd0e1f89ec..dafba133a0f5 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/StreamItemMessage.cs @@ -23,17 +23,6 @@ public StreamItemMessage(string invocationId, object? item) : base(invocationId) Item = item; } - /// - /// - /// - /// - /// - /// - public StreamItemMessage(string invocationId, long? sequenceId, object? item) : base(invocationId, sequenceId) - { - Item = item; - } - /// public override string ToString() { diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index ad315c86af9f..f883dfadbd57 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -761,23 +761,28 @@ internal void Ack(AckMessage ackMessage) _messageBuffer.Ack(ackMessage); } - private long? GetSequenceId() - { - if (UseAcks) - { - return Interlocked.Increment(ref _sequenceId); - } - return null; - } + //private long? GetSequenceId() + //{ + // if (UseAcks) + // { + // return Interlocked.Increment(ref _sequenceId); + // } + // return null; + //} + + private long _currentReceivingSequenceId; internal bool ShouldProcessMessage(HubInvocationMessage message) { - if (message.SequenceId <= _latestReceivedSequenceId) + var currentId = _currentReceivingSequenceId; + _currentReceivingSequenceId++; + if (currentId <= _latestReceivedSequenceId) { // Ignore, this is a duplicate message return false; } - _latestReceivedSequenceId = message.SequenceId!.Value; + _latestReceivedSequenceId = currentId; + return true; } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 46ea44b5ec2b..85cae246ce3e 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -218,7 +218,7 @@ private Task ProcessInvocationBindingFailure(HubConnectionContext connection, In var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.", bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); - return SendInvocationError(bindingFailureMessage.InvocationId, connection.GetSequenceId(), connection, errorMessage); + return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage); } private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage) @@ -227,7 +227,7 @@ private Task ProcessStreamBindingFailure(HubConnectionContext connection, Stream "Failed to bind Stream message.", bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); - var message = CompletionMessage.WithError(bindingFailureMessage.Id, connection.GetSequenceId(), errorString); + var message = CompletionMessage.WithError(bindingFailureMessage.Id, errorString); Log.ClosingStreamWithBindingError(_logger, message); // ignore failure, it means the client already completed the stream or the stream never existed on the server @@ -260,7 +260,7 @@ private Task ProcessInvocation(HubConnectionContext connection, { // Send an error to the client. Then let the normal completion process occur return connection.WriteAsync(CompletionMessage.WithError( - hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask(); + hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask(); } else { @@ -303,7 +303,7 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionCon if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor, hubMethodInvocationMessage.Arguments, hub)) { Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized"); return true; } @@ -321,7 +321,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.Ge { var ex = new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}."); Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); return true; } @@ -374,7 +374,7 @@ static async Task ExecuteInvocation(DefaultHubDispatcher dispatcher, catch (Exception ex) { Log.FailedInvokingHubMethod(logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, enableDetailedErrors)); return; } @@ -392,7 +392,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.Ge if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { // Invoke Async, one response expected - await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), result)); + await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); } } @@ -414,13 +414,13 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.Ge catch (TargetInvocationException ex) { Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex.InnerException ?? ex, _enableDetailedErrors)); } catch (Exception ex) { Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), connection, + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); } } @@ -446,7 +446,7 @@ private static ValueTask CleanupInvocation(HubConnectionContext connection, HubM { foreach (var stream in hubMessage.StreamIds) { - connection.StreamTracker.TryComplete(CompletionMessage.Empty(stream, connection.GetSequenceId())); + connection.StreamTracker.TryComplete(CompletionMessage.Empty(stream)); } } @@ -495,7 +495,7 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect await using var enumerator = descriptor.FromReturnedStream(result, streamCts.Token); Log.StreamingResult(_logger, invocationId, descriptor.MethodExecutor); - var streamItemMessage = new StreamItemMessage(invocationId, connection.GetSequenceId(), null); + var streamItemMessage = new StreamItemMessage(invocationId, null); while (await enumerator.MoveNextAsync()) { @@ -526,7 +526,7 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect streamCts.Dispose(); connection.ActiveRequestCancellationSources.TryRemove(invocationId, out _); - await connection.WriteAsync(CompletionMessage.WithError(invocationId, connection.GetSequenceId(), error)); + await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); } } @@ -572,15 +572,14 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect } } - private static async Task SendInvocationError(string? invocationId, long? sequenceId, - HubConnectionContext connection, string errorMessage) + private static async Task SendInvocationError(string? invocationId, HubConnectionContext connection, string errorMessage) { if (string.IsNullOrEmpty(invocationId)) { return; } - await connection.WriteAsync(CompletionMessage.WithError(invocationId, sequenceId, errorMessage)); + await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage)); } private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true) @@ -624,7 +623,7 @@ private async Task ValidateInvocationMode(HubMethodDescriptor hubMethodDes if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage); - await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, connection.GetSequenceId(), + await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation.")); } @@ -634,7 +633,7 @@ await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessa if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse) { Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage); - await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, connection.GetSequenceId(), + await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId!, $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation.")); return false; diff --git a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs index 4ff10efef53d..431225e0280a 100644 --- a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs +++ b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs @@ -8,13 +8,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class MessageBuffer { - private readonly (SerializedHubMessage? Message, long? SequenceId)[] _buffer; + private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; private int _index; + private long _totalMessageCount; // TODO: pass in limits public MessageBuffer() { - _buffer = new (SerializedHubMessage? Message, long? SequenceId)[10]; + _buffer = new (SerializedHubMessage? Message, long SequenceId)[10]; } public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, @@ -30,10 +31,10 @@ public async ValueTask WriteAsync(PipeWriter pipeWriter, Serialized // ... } - long? sequenceId; if (hubMessage.Message is HubInvocationMessage invocationMessage) { - sequenceId = invocationMessage.SequenceId; + //sequenceId = invocationMessage.SequenceId; + _totalMessageCount++; } else { @@ -41,7 +42,7 @@ public async ValueTask WriteAsync(PipeWriter pipeWriter, Serialized return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); } - _buffer[_index] = (hubMessage, sequenceId); + _buffer[_index] = (hubMessage, _totalMessageCount); _index = (_index + 1) % _buffer.Length; return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); } @@ -52,9 +53,9 @@ public void Ack(AckMessage ackMessage) for (var i = 0; i < _buffer.Length; i++) { var currentIndex = (index + i) % _buffer.Length; - if (_buffer[currentIndex].SequenceId is long id && id <= ackMessage.SequenceId) + if (_buffer[currentIndex].SequenceId <= ackMessage.SequenceId) { - _buffer[currentIndex] = (null, null); + _buffer[currentIndex] = (null, long.MinValue); } } From a1db3d712462f872033ecef16c2c91749052f0ec Mon Sep 17 00:00:00 2001 From: Brennan Date: Tue, 16 May 2023 09:28:41 -0700 Subject: [PATCH 16/25] stash --- .../csharp/Client.Core/src/HubConnection.cs | 29 ++- .../src/Internal/SerializedHubMessage.cs | 191 ++++++++++++++++++ ...soft.AspNetCore.SignalR.Client.Core.csproj | 1 + .../src/Internal/WebSocketsTransport.cs | 10 +- .../src/Internal/HttpConnectionContext.cs | 14 ++ .../src/Protocol/JsonHubProtocol.cs | 34 ++-- src/SignalR/common/Shared/MessageBuffer.cs | 139 +++++++++++++ .../SignalR.Common/src/Protocol/AckMessage.cs | 10 + .../src/Protocol/HubProtocolConstants.cs | 2 + .../server/Core/src/HubConnectionContext.cs | 15 +- .../Core/src/Internal/DefaultHubDispatcher.cs | 4 + .../server/Core/src/Internal/MessageBuffer.cs | 64 ------ .../Microsoft.AspNetCore.SignalR.Core.csproj | 1 + 13 files changed, 420 insertions(+), 94 deletions(-) create mode 100644 src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs create mode 100644 src/SignalR/common/Shared/MessageBuffer.cs delete mode 100644 src/SignalR/server/Core/src/Internal/MessageBuffer.cs diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 797ecee716c8..b4588c24e2ae 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -79,6 +79,7 @@ public partial class HubConnection : IAsyncDisposable private readonly ReconnectingConnectionState _state; private bool _disposed; + private MessageBuffer _buffer = new(); /// /// Occurs when the connection is closed. The connection could be closed due to an error or due to either the server or client intentionally @@ -946,11 +947,21 @@ private async Task InvokeStreamCore(ConnectionState connectionState, string meth private async Task SendHubMessage(ConnectionState connectionState, HubMessage hubMessage, CancellationToken cancellationToken = default) { _state.AssertConnectionValid(); - _protocol.WriteMessage(hubMessage, connectionState.Connection.Transport.Output); Log.SendingMessage(_logger, hubMessage); - await connectionState.Connection.Transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); + // TODO + var isAck = true; + if (isAck) + { + await _buffer.WriteAsync(connectionState.Connection.Transport.Output, new SerializedHubMessage(hubMessage), _protocol, cancellationToken).ConfigureAwait(false); + } + else + { + _protocol.WriteMessage(hubMessage, connectionState.Connection.Transport.Output); + + await connectionState.Connection.Transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); + } Log.MessageSent(_logger, hubMessage); // We've sent a message, so don't ping for a while @@ -1004,6 +1015,14 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess Log.ResettingKeepAliveTimer(_logger); connectionState.ResetTimeout(); + if (true && message is HubInvocationMessage hubInvocation) + { + if (!_buffer.ShouldProcessMessage(hubInvocation)) + { + return null; + } + } + InvocationRequest? irq; switch (message) { @@ -1055,6 +1074,12 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess Log.ReceivedPing(_logger); // timeout is reset above, on receiving any message break; + case AckMessage ackMessage: + _buffer.Ack(ackMessage); + break; + case SequenceMessage sequenceMessage: + _buffer.ResetSequence(sequenceMessage); + break; default: throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}"); } diff --git a/src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs b/src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs new file mode 100644 index 000000000000..4fd1d41d5df9 --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs @@ -0,0 +1,191 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +/// +/// Represents a serialization cache for a single message. +/// +internal class SerializedHubMessage +{ + private SerializedMessage _cachedItem1; + private SerializedMessage _cachedItem2; + private List? _cachedItems; + private readonly object _lock = new object(); + + /// + /// Gets the hub message for the serialization cache. + /// + public HubMessage? Message { get; } + + /// + /// Initializes a new instance of the class. + /// + /// A collection of already serialized messages to cache. + public SerializedHubMessage(IReadOnlyList messages) + { + // A lock isn't needed here because nobody has access to this type until the constructor finishes. + for (var i = 0; i < messages.Count; i++) + { + var message = messages[i]; + SetCacheUnsynchronized(message.ProtocolName, message.Serialized); + } + } + + /// + /// Initializes a new instance of the class. + /// + /// The hub message for the cache. This will be serialized with an in to get the message's serialized representation. + public SerializedHubMessage(HubMessage message) + { + Message = message; + } + + /// + /// Gets the serialized representation of the using the specified . + /// + /// The protocol used to create the serialized representation. + /// The serialized representation of the . + public ReadOnlyMemory GetSerializedMessage(IHubProtocol protocol) + { + lock (_lock) + { + if (!TryGetCachedUnsynchronized(protocol.Name, out var serialized)) + { + if (Message == null) + { + throw new InvalidOperationException( + "This message was received from another server that did not have the requested protocol available."); + } + + serialized = protocol.GetMessageBytes(Message); + SetCacheUnsynchronized(protocol.Name, serialized); + } + + return serialized; + } + } + + // Used for unit testing. + internal IReadOnlyList GetAllSerializations() + { + // Even if this is only used in tests, let's do it right. + lock (_lock) + { + if (_cachedItem1.ProtocolName == null) + { + return Array.Empty(); + } + + var list = new List(2); + list.Add(_cachedItem1); + + if (_cachedItem2.ProtocolName != null) + { + list.Add(_cachedItem2); + + if (_cachedItems != null) + { + list.AddRange(_cachedItems); + } + } + + return list; + } + } + + private void SetCacheUnsynchronized(string protocolName, ReadOnlyMemory serialized) + { + // We set the fields before moving on to the list, if we need it to hold more than 2 items. + // We have to read/write these fields under the lock because the structs might tear and another + // thread might observe them half-assigned + + if (_cachedItem1.ProtocolName == null) + { + _cachedItem1 = new SerializedMessage(protocolName, serialized); + } + else if (_cachedItem2.ProtocolName == null) + { + _cachedItem2 = new SerializedMessage(protocolName, serialized); + } + else + { + if (_cachedItems == null) + { + _cachedItems = new List(); + } + + foreach (var item in _cachedItems) + { + if (string.Equals(item.ProtocolName, protocolName, StringComparison.Ordinal)) + { + // No need to add + return; + } + } + + _cachedItems.Add(new SerializedMessage(protocolName, serialized)); + } + } + + private bool TryGetCachedUnsynchronized(string protocolName, out ReadOnlyMemory result) + { + if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = _cachedItem1.Serialized; + return true; + } + + if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = _cachedItem2.Serialized; + return true; + } + + if (_cachedItems != null) + { + foreach (var serializedMessage in _cachedItems) + { + if (string.Equals(serializedMessage.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = serializedMessage.Serialized; + return true; + } + } + } + + result = default; + return false; + } +} + +internal readonly struct SerializedMessage +{ + /// + /// Gets the protocol of the serialized message. + /// + public string ProtocolName { get; } + + /// + /// Gets the serialized representation of the message. + /// + public ReadOnlyMemory Serialized { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The protocol of the serialized message. + /// The serialized representation of the message. + public SerializedMessage(string protocolName, ReadOnlyMemory serialized) + { + ProtocolName = protocolName; + Serialized = serialized; + } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj index 65cf5d649c19..e0cd9f43d95b 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj +++ b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj @@ -14,6 +14,7 @@ + diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 0078f421c263..61e3c67727ae 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -297,11 +297,11 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio { // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) DuplexPipePair pair; - if (_useAck) - { - pair = CreateAckConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - } - else + //if (_useAck) + //{ + // pair = CreateAckConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + //} + //else { pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 70c8a3f341bf..cd139b9cbc8c 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -540,6 +540,9 @@ internal async Task CancelPreviousPoll(HttpContext context) if (UseAcks && TransportType == HttpTransportType.WebSockets) { Application.Input.CancelPendingRead(); + var prevPipe = Application.Input; + UpdateConnectionPair(); + prevPipe.Complete(new Exception()); } try @@ -642,6 +645,17 @@ public void RequestClose() ThreadPool.UnsafeQueueUserWorkItem(static cts => ((CancellationTokenSource)cts!).Cancel(), _connectionCloseRequested); } + private void UpdateConnectionPair() + { + var input = new Pipe(_options.TransportPipeOptions); + + var transportToApplication = new DuplexPipe(Transport.Input, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, Application.Output); + + Application = applicationToTransport; + Transport = transportToApplication; + } + private static partial class Log { [LoggerMessage(1, LogLevel.Trace, "Disposing connection {TransportConnectionId}.", EventName = "DisposingConnection")] diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index be03c0d8fd45..01322cd1ecab 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -461,6 +461,8 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) return BindCloseMessage(error, allowReconnect); case HubProtocolConstants.AckMessageType: return BindAckMessage(sequenceId); + case HubProtocolConstants.SequenceMessageType: + return BindSequenceMessage(sequenceId); case null: throw new InvalidDataException($"Missing required property '{TypePropertyName}'."); default: @@ -557,6 +559,10 @@ private void WriteMessageCore(HubMessage message, IBufferWriter stream) WriteMessageType(writer, HubProtocolConstants.AckMessageType); WriteAckMessage(m, writer); break; + case SequenceMessage m: + WriteMessageType(writer, HubProtocolConstants.SequenceMessageType); + WriteSequenceMessage(m, writer); + break; default: throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); } @@ -586,7 +592,6 @@ private static void WriteHeaders(Utf8JsonWriter writer, HubInvocationMessage mes private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); - WriteSequenceId(message, writer); if (!string.IsNullOrEmpty(message.Error)) { writer.WriteString(ErrorPropertyNameBytes, message.Error); @@ -615,13 +620,11 @@ private void WriteCompletionMessage(CompletionMessage message, Utf8JsonWriter wr private static void WriteCancelInvocationMessage(CancelInvocationMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); - WriteSequenceId(message, writer); } private void WriteStreamItemMessage(StreamItemMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); - WriteSequenceId(message, writer); writer.WritePropertyName(ItemPropertyNameBytes); if (message.Item == null) @@ -637,7 +640,6 @@ private void WriteStreamItemMessage(StreamItemMessage message, Utf8JsonWriter wr private void WriteInvocationMessage(InvocationMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); - WriteSequenceId(message, writer); writer.WriteString(TargetPropertyNameBytes, message.Target); WriteArguments(message.Arguments, writer); @@ -648,7 +650,6 @@ private void WriteInvocationMessage(InvocationMessage message, Utf8JsonWriter wr private void WriteStreamInvocationMessage(StreamInvocationMessage message, Utf8JsonWriter writer) { WriteInvocationId(message, writer); - WriteSequenceId(message, writer); writer.WriteString(TargetPropertyNameBytes, message.Target); WriteArguments(message.Arguments, writer); @@ -674,6 +675,11 @@ private static void WriteAckMessage(AckMessage message, Utf8JsonWriter writer) writer.WriteNumber(SequenceIdPropertyName, message.SequenceId); } + private static void WriteSequenceMessage(SequenceMessage message, Utf8JsonWriter writer) + { + writer.WriteNumber(SequenceIdPropertyName, message.SequenceId); + } + private void WriteArguments(object?[] arguments, Utf8JsonWriter writer) { writer.WriteStartArray(ArgumentsPropertyNameBytes); @@ -714,14 +720,6 @@ private static void WriteInvocationId(HubInvocationMessage message, Utf8JsonWrit } } - private static void WriteSequenceId(HubInvocationMessage message, Utf8JsonWriter writer) - { - if (message.SequenceId is not null) - { - writer.WriteNumber(SequenceIdPropertyNameBytes, message.SequenceId.Value); - } - } - private static void WriteMessageType(Utf8JsonWriter writer, int type) { writer.WriteNumber(TypePropertyNameBytes, type); @@ -896,6 +894,16 @@ private static AckMessage BindAckMessage(long? sequenceId) return new AckMessage(sequenceId.Value); } + private static SequenceMessage BindSequenceMessage(long? sequenceId) + { + if (sequenceId is null) + { + throw new InvalidDataException("Missing 'sequenceId' in Sequence message."); + } + + return new SequenceMessage(sequenceId.Value); + } + private static HubMessage ApplyHeaders(HubMessage message, Dictionary? headers) { if (headers != null && message is HubInvocationMessage invocationMessage) diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs new file mode 100644 index 000000000000..81dc6b449f6b --- /dev/null +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -0,0 +1,139 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal sealed class MessageBuffer +{ + private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; + private int _index; + private long _totalMessageCount; + + // TODO: pass in limits + public MessageBuffer() + { + _buffer = new (SerializedHubMessage? Message, long SequenceId)[10]; + } + + public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, + CancellationToken cancellationToken) + { + // No lock because this is always called in a single async loop? + // And other methods don't affect the checks here? + + // TODO: Backpressure + + if (_buffer[_index].Message is not null) + { + // ... + } + + try + { + + if (hubMessage.Message is HubInvocationMessage invocationMessage) + { + _totalMessageCount++; + } + else + { + // Non-ackable message, don't add to buffer + return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + } + + _buffer[_index] = (hubMessage, _totalMessageCount); + _index = (_index + 1) % _buffer.Length; + return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + // TODO: specific exception or some identifier needed + + // wait for reconnect, send sequencemessage, and then do resend loop + + long latestAckedIndex = -1; + for (var i = 0; i < _buffer.Length - 1; i++) + { + if (_buffer[(_index + i + 1) % _buffer.Length].SequenceId > long.MinValue) + { + latestAckedIndex = (_index + i + 1) % _buffer.Length; + } + } + + if (latestAckedIndex == -1) + { + // no unacked messages, probably not possible + // because we are in the middle of writing a message when we get here, so there should be 1 minimum + } + + protocol.WriteMessage(new SequenceMessage(_buffer[latestAckedIndex].SequenceId), pipeWriter); + await pipeWriter.FlushAsync(cancellationToken).ConfigureAwait(false); + + for (var i = 0; i < _buffer.Length; i++) + { + var item = _buffer[(latestAckedIndex + i) % _buffer.Length]; + if (item.SequenceId > long.MinValue) + { + await pipeWriter.WriteAsync(item.Message!.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + } + else + { + break; + } + } + + return new FlushResult(isCanceled: false, isCompleted: false); + } + } + + public void Ack(AckMessage ackMessage) + { + var index = _index; + for (var i = 0; i < _buffer.Length; i++) + { + var currentIndex = (index + i) % _buffer.Length; + if (_buffer[currentIndex].SequenceId <= ackMessage.SequenceId) + { + _buffer[currentIndex] = (null, long.MinValue); + } + } + + // Release backpressure? + } + + private long _currentReceivingSequenceId; + private long _latestReceivedSequenceId = long.MinValue; + + internal bool ShouldProcessMessage(HubInvocationMessage message) + { + // TODO: if we're expecting a sequence message but get here we should error + + var currentId = _currentReceivingSequenceId; + _currentReceivingSequenceId++; + if (currentId <= _latestReceivedSequenceId) + { + // Ignore, this is a duplicate message + return false; + } + _latestReceivedSequenceId = currentId; + + return true; + } + + internal void ResetSequence(SequenceMessage sequenceMessage) + { + // TODO: is a sequence message expected right now? + + if (sequenceMessage.SequenceId > _currentReceivingSequenceId) + { + throw new Exception("Sequence ID greater than amount we've acked"); + } + _currentReceivingSequenceId = sequenceMessage.SequenceId; + } +} diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs index 916e8e0a2aba..3065673bfe0c 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs @@ -28,3 +28,13 @@ public AckMessage(long sequenceId) /// public long SequenceId { get; } } + +public sealed class SequenceMessage : HubMessage +{ + public SequenceMessage(long sequenceId) + { + SequenceId = sequenceId; + } + + public long SequenceId { get; } +} diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs index 538e07ce0e03..0d32dbc2f235 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs @@ -47,4 +47,6 @@ public static class HubProtocolConstants /// /// public const int AckMessageType = 8; + + public const int SequenceMessageType = 9; } diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index f883dfadbd57..07944315e7ba 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -52,7 +52,6 @@ public partial class HubConnectionContext internal bool UseAcks; private long _sequenceId; - private long _latestReceivedSequenceId = long.MinValue; /// /// Initializes a new instance of the class. @@ -774,15 +773,11 @@ internal void Ack(AckMessage ackMessage) internal bool ShouldProcessMessage(HubInvocationMessage message) { - var currentId = _currentReceivingSequenceId; - _currentReceivingSequenceId++; - if (currentId <= _latestReceivedSequenceId) - { - // Ignore, this is a duplicate message - return false; - } - _latestReceivedSequenceId = currentId; + return _messageBuffer.ShouldProcessMessage(message); + } - return true; + internal void ResetSequence(SequenceMessage sequenceMessage) + { + _messageBuffer.ResetSequence(sequenceMessage); } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 85cae246ce3e..1ff4b67d24a7 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -203,6 +203,10 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe connection.Ack(ackMessage); break; + case SequenceMessage sequenceMessage: + connection.ResetSequence(sequenceMessage); + break; + // Other kind of message we weren't expecting default: Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName!); diff --git a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs b/src/SignalR/server/Core/src/Internal/MessageBuffer.cs deleted file mode 100644 index 431225e0280a..000000000000 --- a/src/SignalR/server/Core/src/Internal/MessageBuffer.cs +++ /dev/null @@ -1,64 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.IO.Pipelines; -using Microsoft.AspNetCore.SignalR.Protocol; - -namespace Microsoft.AspNetCore.SignalR.Internal; - -internal sealed class MessageBuffer -{ - private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; - private int _index; - private long _totalMessageCount; - - // TODO: pass in limits - public MessageBuffer() - { - _buffer = new (SerializedHubMessage? Message, long SequenceId)[10]; - } - - public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, - CancellationToken cancellationToken) - { - // No lock because this is always called in a single async loop? - // And other methods don't affect the checks here? - - // TODO: Backpressure - - if (_buffer[_index].Message is not null) - { - // ... - } - - if (hubMessage.Message is HubInvocationMessage invocationMessage) - { - //sequenceId = invocationMessage.SequenceId; - _totalMessageCount++; - } - else - { - // Non-ackable message, don't add to buffer - return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); - } - - _buffer[_index] = (hubMessage, _totalMessageCount); - _index = (_index + 1) % _buffer.Length; - return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken); - } - - public void Ack(AckMessage ackMessage) - { - var index = _index; - for (var i = 0; i < _buffer.Length; i++) - { - var currentIndex = (index + i) % _buffer.Length; - if (_buffer[currentIndex].SequenceId <= ackMessage.SequenceId) - { - _buffer[currentIndex] = (null, long.MinValue); - } - } - - // Release backpressure? - } -} diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index db4c7d7e81df..b4552ea0c9f6 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -17,6 +17,7 @@ + From a22ad6c3c3f9c5865b9bad787858429dd3e852a0 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Tue, 16 May 2023 15:18:21 -0700 Subject: [PATCH 17/25] stash, but it 'works' --- .../src/Features/IReconnectFeature.cs | 15 + .../csharp/Client.Core/src/HubConnection.cs | 12 +- .../src/HttpConnection.cs | 6 + .../src/Internal/WebSocketsTransport.cs | 102 +-- ....AspNetCore.Http.Connections.Client.csproj | 3 - .../src/Internal/HttpConnectionContext.cs | 18 +- .../src/Internal/HttpConnectionManager.cs | 17 +- .../Transports/WebSocketsServerTransport.cs | 57 -- ...crosoft.AspNetCore.Http.Connections.csproj | 3 - src/SignalR/common/Shared/AckPipeReader.cs | 190 ---- src/SignalR/common/Shared/AckPipeWriter.cs | 107 --- src/SignalR/common/Shared/MessageBuffer.cs | 151 +++- .../common/Shared/ParseAckPipeReader.cs | 176 ---- .../test/Internal/Protocol/AckPipeTests.cs | 830 ------------------ ...oft.AspNetCore.SignalR.Common.Tests.csproj | 3 - .../server/Core/src/HubConnectionContext.cs | 11 +- 16 files changed, 195 insertions(+), 1506 deletions(-) create mode 100644 src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs delete mode 100644 src/SignalR/common/Shared/AckPipeReader.cs delete mode 100644 src/SignalR/common/Shared/AckPipeWriter.cs delete mode 100644 src/SignalR/common/Shared/ParseAckPipeReader.cs delete mode 100644 src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs diff --git a/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs new file mode 100644 index 000000000000..94e30d62366a --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Connections.Abstractions; + +public interface IReconnectFeature +{ + public Action NotifyOnReconnect { get; set; } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index b4588c24e2ae..81672ccbd5cd 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -16,6 +16,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Shared; @@ -79,7 +80,7 @@ public partial class HubConnection : IAsyncDisposable private readonly ReconnectingConnectionState _state; private bool _disposed; - private MessageBuffer _buffer = new(); + private MessageBuffer? _buffer; /// /// Occurs when the connection is closed. The connection could be closed due to an error or due to either the server or client intentionally @@ -477,6 +478,9 @@ private async Task StartAsyncCore(CancellationToken cancellationToken) var connection = await _connectionFactory.ConnectAsync(_endPoint, cancellationToken).ConfigureAwait(false); var startingConnectionState = new ConnectionState(connection, this); + // TODO: probably go on ConnectionState + _buffer = new MessageBuffer(connection, _protocol); + // From here on, if an error occurs we need to shut down the connection because // we still own it. try @@ -954,7 +958,7 @@ private async Task SendHubMessage(ConnectionState connectionState, HubMessage hu var isAck = true; if (isAck) { - await _buffer.WriteAsync(connectionState.Connection.Transport.Output, new SerializedHubMessage(hubMessage), _protocol, cancellationToken).ConfigureAwait(false); + await _buffer.WriteAsync(new SerializedHubMessage(hubMessage), _protocol, cancellationToken).ConfigureAwait(false); } else { @@ -1260,6 +1264,10 @@ private async Task HandshakeAsync(ConnectionState startingConnectionState, Cance } Log.HandshakeComplete(_logger); + + var f = startingConnectionState.Connection.Features.Get(); + f.NotifyOnReconnect = _buffer.Resend; + break; } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 18115c572434..aa08bfd42e77 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -11,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Client.Internal; using Microsoft.AspNetCore.Http.Features; @@ -530,6 +531,11 @@ private async Task StartTransport(Uri connectUrl, HttpTransportType transportTyp // We successfully started, set the transport properties (we don't want to set these until the transport is definitely running). _transport = transport; + if (_httpConnectionOptions.UseAcks && _transport is IReconnectFeature reconnectFeature) + { + Features.Set(reconnectFeature); + } + Log.TransportStarted(_logger, transportType); } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 61e3c67727ae..edbe3c9d2c5b 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -17,6 +17,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -24,7 +25,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; -internal sealed partial class WebSocketsTransport : ITransport +internal sealed partial class WebSocketsTransport : ITransport, IReconnectFeature { private WebSocket? _webSocket; private IDuplexPipe? _application; @@ -41,6 +42,7 @@ internal sealed partial class WebSocketsTransport : ITransport // Used for reconnect (when enabled) to determine if the close was ungraceful or not, reconnect only happens on ungraceful disconnect // The assumption is that a graceful close was triggered purposefully by either the client or server and a reconnect shouldn't occur private bool _gracefulClose; + private Action? _notifyOnReconnect; internal Task Running { get; private set; } = Task.CompletedTask; @@ -48,6 +50,8 @@ internal sealed partial class WebSocketsTransport : ITransport public PipeWriter Output => _transport!.Output; + public Action NotifyOnReconnect { get => _notifyOnReconnect is not null ? _notifyOnReconnect : () => { }; set => _notifyOnReconnect = value; } + public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func> accessTokenProvider, HttpClient? httpClient, bool useAck = false) { @@ -296,92 +300,15 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio if (_transport is null) { // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - DuplexPipePair pair; - //if (_useAck) - //{ - // pair = CreateAckConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - //} - //else - { - pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); - } + var pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); _transport = pair.Transport; _application = pair.Application; } - else - { - if (_application!.Input is AckPipeReader reader) - { - if (reader.Resend()) - { - // Start reconnect ack handshake - // 1. Send ack ID to server for last message we recieved from server before we disconnected - // 2. Read from server to get the last ack ID it received before we disconnecting - // 3. Resume normal send/receive loops - - ignoreFirstCanceled = true; - var buf = new byte[AckPipeWriter.FrameHeaderSize]; - AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)_transport.Output).LastAck); - await _webSocket.SendAsync(new ArraySegment(buf, 0, AckPipeWriter.FrameHeaderSize), _webSocketMessageType, true, _stopCts.Token).ConfigureAwait(false); - - Array.Clear(buf, 0, buf.Length); - // server sends 0 length, but with latest ack, so there shouldn't be more than a frame of data sent - var readLength = 0; - WebSocketReceiveResult? receiveResult; - do - { - receiveResult = await _webSocket.ReceiveAsync(new ArraySegment(buf, readLength, AckPipeWriter.FrameHeaderSize - readLength), _stopCts.Token).ConfigureAwait(false); - readLength += receiveResult.Count; - } while (readLength < AckPipeWriter.FrameHeaderSize && !receiveResult.EndOfMessage); - - if (readLength != AckPipeWriter.FrameHeaderSize) - { - _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received less data than expected.")); - _application.Input.Complete(); - return; - } - - if (!receiveResult.EndOfMessage) - { - _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received more data than expected.")); - _application.Input.Complete(); - return; - } - - // Parsing ack id and updating reader here avoids issue where we send to server before receive loop runs, which is what normally updates ack - // This avoids resending data that was already acked - var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(buf), reader); - - // TODO: why do we need to unblock the receive loop to not delay/block shutdown sometimes? - // Looks like calling stop/dispose on the client doesn't avoid the reconnect cycle in the transport, we'll need to fix that - var flushResult = await _application.Output.FlushAsync(default).ConfigureAwait(false); - } - } - else - { - Debug.Assert(false); - } - } // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 Running = ProcessSocketAsync(_webSocket, url, ignoreFirstCanceled); - - static DuplexPipePair CreateAckConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - // Use for one side only, i.e. server - var ackWriter = new AckPipeWriter(output); - var ackReader = new AckPipeReader(output); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); - var transportToApplication = new DuplexPipe(ackReader, input.Writer); - var applicationToTransport = new DuplexPipe(transportReader, ackWriter); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } } private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirstCanceled) @@ -439,6 +366,7 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirs if (_useAck && !_gracefulClose) { + UpdateConnectionPair(); await StartAsync(url, _webSocketMessageType == WebSocketMessageType.Binary ? TransferFormat.Binary : TransferFormat.Text, default).ConfigureAwait(false); } } @@ -698,4 +626,20 @@ public async Task StopAsync() Log.TransportStopped(_logger, null); } + + private void UpdateConnectionPair() + { + var prevPipe = _application!.Input; + var input = new Pipe(_httpConnectionOptions.TransportPipeOptions); + + var transportToApplication = new DuplexPipe(_transport!.Input, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, _application!.Output); + + _application = applicationToTransport; + _transport = transportToApplication; + + prevPipe.Complete(new Exception()); + + _notifyOnReconnect.Invoke(); + } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj b/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj index d6204265fcb0..e79a1fd7bdba 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Microsoft.AspNetCore.Http.Connections.Client.csproj @@ -11,9 +11,6 @@ - - - diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index cd139b9cbc8c..45670068696d 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -7,6 +7,7 @@ using System.Security.Claims; using System.Security.Principal; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; @@ -18,6 +19,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal; +internal sealed class Reconnect : IReconnectFeature +{ + public Action NotifyOnReconnect { get; set; } +} + internal sealed partial class HttpConnectionContext : ConnectionContext, IConnectionIdFeature, IConnectionItemsFeature, @@ -92,6 +98,12 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge Features.Set(this); Features.Set(this); + if (useAcks) + { + var reconnectFeature = new Reconnect(); + Features.Set(reconnectFeature); + } + _connectionClosedTokenSource = new CancellationTokenSource(); ConnectionClosed = _connectionClosedTokenSource.Token; @@ -540,9 +552,7 @@ internal async Task CancelPreviousPoll(HttpContext context) if (UseAcks && TransportType == HttpTransportType.WebSockets) { Application.Input.CancelPendingRead(); - var prevPipe = Application.Input; UpdateConnectionPair(); - prevPipe.Complete(new Exception()); } try @@ -647,6 +657,7 @@ public void RequestClose() private void UpdateConnectionPair() { + var prevPipe = Application.Input; var input = new Pipe(_options.TransportPipeOptions); var transportToApplication = new DuplexPipe(Transport.Input, input.Writer); @@ -654,6 +665,9 @@ private void UpdateConnectionPair() Application = applicationToTransport; Transport = transportToApplication; + + prevPipe.Complete(new Exception()); + Features.GetRequiredFeature().NotifyOnReconnect?.Invoke(); } private static partial class Log diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 742e990f1f7e..3086ee363ba1 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -91,27 +91,12 @@ internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions HttpConnectionsEventSource.Log.ConnectionStart(id); _metrics.ConnectionStart(metricsContext); - var pair = DuplexPipe.CreateConnectionPair(options.TransportPipeOptions, options.AppPipeOptions); + var pair = CreateConnectionPair(options.TransportPipeOptions, options.AppPipeOptions); var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger, metricsContext, pair.Application, pair.Transport, options, useAck); _connections.TryAdd(connectionToken, (connection, startTimestamp)); return connection; - - static DuplexPipePair CreateAckConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - // Use for one side only, i.e. server - var ackWriterApp = new AckPipeWriter(output); - var ackReader = new AckPipeReader(output); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReader); - var transportToApplication = new DuplexPipe(ackReader, input.Writer); - var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); - - return new DuplexPipePair(transportToApplication, applicationToTransport); - } } public void RemoveConnection(string id, HttpTransportType transportType, HttpConnectionStopStatus status) diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 202ec9df9697..693f8b2fa4f7 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -61,63 +61,6 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationTok public async Task ProcessSocketAsync(WebSocket socket) { var ignoreFirstCancel = false; - if (_application.Input is AckPipeReader reader) - { - _aborted = false; - // TODO: why is this needed on initial connection start, ideally should be in if condition below - ignoreFirstCancel = true; - // TODO: check if the pipe was used previously? - // Currently checked in Resend - if (reader.Resend()) - { - // Start reconnect ack handshake - // 1. Read from client to get the last ack ID it received before disconnecting - // 2. Send ack ID to client for last message we received from client before it disconnected - // 3. Resume normal send/receive loops - - var buf = new byte[AckPipeWriter.FrameHeaderSize]; - WebSocketReceiveResult? res; - var readLength = 0; - do - { - res = await socket.ReceiveAsync(new ArraySegment(buf, readLength, AckPipeWriter.FrameHeaderSize - readLength), _connection.Cancellation?.Token ?? default); - readLength += res.Count; - } while (readLength < AckPipeWriter.FrameHeaderSize && !res.EndOfMessage); - - if (readLength != AckPipeWriter.FrameHeaderSize) - { - _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received less data than expected.")); - _application.Input.Complete(); - return; - } - - if (!res.EndOfMessage) - { - _application.Output.Complete(new InvalidDataException("WebSocket reconnect handshake received more data than expected.")); - _application.Input.Complete(); - return; - } - - // Needed so that the readers ack position gets updated and we don't re-send messages to client - // Normally this would be done by the HubConnectionHandler loop, but that requires a new message to be read - // so we instead make sure it's updated immediately here - var parsedLen = ParseAckPipeReader.ParseFrame(new ReadOnlySequence(buf), reader); - Debug.Assert(parsedLen == 0); - // we don't need to write to the pipe if we parse the frame? - //await _application.Output.WriteAsync(buf); - - var webSocketMessageType = (_connection.ActiveFormat == TransferFormat.Binary - ? WebSocketMessageType.Binary - : WebSocketMessageType.Text); - Array.Clear(buf); - Debug.Assert(_connection.Transport.Output is AckPipeWriter); - AckPipeWriter.WriteFrame(buf.AsSpan(), 0, ((AckPipeWriter)_connection.Transport.Output).LastAck); - _connection.StartSendCancellation(); - // send without going through the Pipe, we don't treat this as an ackable message - await socket.SendAsync(buf, webSocketMessageType, endOfMessage: true, _connection.SendingToken); - _connection.StopSendCancellation(); - } - } var receiving = StartReceiving(socket); var sending = StartSending(socket, ignoreFirstCancel); diff --git a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj index 10c11fbe5ac6..e6ee74ec7735 100644 --- a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj @@ -17,9 +17,6 @@ - - - diff --git a/src/SignalR/common/Shared/AckPipeReader.cs b/src/SignalR/common/Shared/AckPipeReader.cs deleted file mode 100644 index 081aef7a3aa0..000000000000 --- a/src/SignalR/common/Shared/AckPipeReader.cs +++ /dev/null @@ -1,190 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.IO.Pipelines; -using System.Threading.Tasks; -using System.Threading; -using System; - -#nullable enable - -namespace Microsoft.AspNetCore.Http.Connections; - -// Wrapper around a PipeReader that adds an Ack position which replaces Consumed -// This allows the underlying pipe to keep un-acked data in the pipe while still providing only new data to the reader -internal sealed class AckPipeReader : PipeReader -{ - private readonly PipeReader _inner; - private readonly object _lock = new object(); - - private SequencePosition _consumed; - private SequencePosition _ackPosition; - private long _ackDiff; - private long _ackId; - private long _totalWritten; - private bool _resend; - - // Accept Pipe instead of PipeReader because we don't want custom pipe implementations to be used with this type - // and Pipe is sealed so a custom one can't be provided - // We rely on undefined implementation details of the default Pipe - public AckPipeReader(Pipe innerPipe) - { - _inner = innerPipe.Reader; - } - - // Update the ack position. This number includes the framing size. - // If byteID is larger than the total bytes sent, it'll throw InvalidOperationException. - public void Ack(long byteID) - { - lock (_lock) - { - //Debug.Assert(_ackDiff == 0); - // ignore? Is this a bad state? - if (byteID < _ackId) - { - return; - } - _ackDiff = byteID - _ackId; - - if (_totalWritten < byteID) - { - Throw(byteID, _totalWritten); - static void Throw(long id, long total) - { - throw new InvalidOperationException($"Ack ID '{id}' is greater than total amount of '{total}' bytes that have been sent."); - } - } - } - } - - public bool Resend() - { - // TODO: Do we need to check this? - Debug.Assert(_resend == false); - if (_totalWritten == 0) - { - return false; - } - // Unblocks ReadAsync and gives a buffer with the examined but not consumed bytes - // This avoids the issue where we have to wait for someone to write to the pipe before - // the receive loop will see what might have been written during disconnect - CancelPendingRead(); - _resend = true; - return true; - } - - public override void AdvanceTo(SequencePosition consumed) - { - AdvanceTo(consumed, consumed); - } - - public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) - { - _consumed = consumed; - // Consumed stays at the ack positions, we store the passed in consumed value for use in ReadAsync so we can give the user only new data - _inner.AdvanceTo(_ackPosition, examined); - - if (_consumed.Equals(_ackPosition)) - { - // Reset to default, we check this in ReadAsync to know if we should provide the current read buffer to the user - // Or slice to the consumed position - _consumed = default; - _ackPosition = default; - } - } - - public override void CancelPendingRead() - { - _inner.CancelPendingRead(); - } - - public override void Complete(Exception? exception = null) - { - _inner.Complete(exception); - } - - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) - { - var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); - var buffer = res.Buffer; - - lock (_lock) - { - if (_ackDiff != 0) - { - // This detects the odd scenario where _consumed points to the end of a Segment and buffer.Slice(_ackDiff) points to the beginning of the next Segment - // While they technically point to different positions, they point to the same concept of "beginning of the next buffer" - var ackSlice = buffer.Slice(_ackDiff); - if (buffer.Slice(_consumed).First.Length == 0 && ackSlice.Start.GetInteger() == 0) - { - // Fix consumed to point to the beginning of the next Segment - _consumed = ackSlice.Start; - } - else if (!_consumed.Equals(default)) - { - var consumedLength = buffer.Slice(_consumed).Length; - if (consumedLength == ackSlice.Length) - { - _consumed = default; - } - else if (consumedLength > ackSlice.Length) - { - // ack is greater than consumed, should not be possible - - // TODO: verify that if ack is less than total but more than consumed this isn't hit - // e.g. 13 bytes in underlying pipe, only consumed 11 during Read+Advance. Will an ack id of 12 be allowed? - Debug.Assert(false); - } - else if (consumedLength < ackSlice.Length) - { - // this is normal, ack id is less than total written - } - } - - buffer = ackSlice; - _ackId += _ackDiff; - _ackDiff = 0; - _ackPosition = buffer.Start; - } - } - - // Slice consumed, unless resending, then slice to ackPosition - if (_resend) - { - _resend = false; - if (buffer.Length != 0 && !_ackPosition.Equals(default)) - { - buffer = buffer.Slice(_ackPosition); - } - // update total written if there is more written to the pipe during a reconnect - // TODO: add tests for both these paths - if (!_consumed.Equals(default)) - { - Debug.Assert(buffer.Length - buffer.Slice(_consumed).Length >= 0); - _totalWritten += buffer.Length - buffer.Slice(_consumed).Length; - } - else - { - _totalWritten += buffer.Length; - } - } - else if (buffer.Length > 0) - { - _ackPosition = buffer.Start; - if (!_consumed.Equals(default)) - { - buffer = buffer.Slice(_consumed); - } - _totalWritten += (uint)buffer.Length; - } - - res = new(buffer, res.IsCanceled, res.IsCompleted); - return res; - } - - public override bool TryRead(out ReadResult result) - { - throw new NotImplementedException(); - } -} diff --git a/src/SignalR/common/Shared/AckPipeWriter.cs b/src/SignalR/common/Shared/AckPipeWriter.cs deleted file mode 100644 index 6dff353ebffe..000000000000 --- a/src/SignalR/common/Shared/AckPipeWriter.cs +++ /dev/null @@ -1,107 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Buffers.Text; -using System.Buffers; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Threading.Tasks; -using System.Threading; -using System; -using System.Buffers.Binary; - -#nullable enable - -namespace Microsoft.AspNetCore.Http.Connections; - -// Wrapper around a PipeWriter that adds framing to writes -internal sealed class AckPipeWriter : PipeWriter -{ - public const int FrameHeaderSize = 24; - private readonly PipeWriter _inner; - internal long LastAck; - - Memory _frameHeader; - bool _shouldAdvanceFrameHeader; - private long _buffered; - - // Accept Pipe instead of PipeWriter because we don't want custom pipe implementations to be used with this type - // and Pipe is sealed so a custom one can't be provided - // We rely on undefined implementation details of the default Pipe - public AckPipeWriter(Pipe innerPipe) - { - _inner = innerPipe.Writer; - } - - public override void Advance(int bytes) - { - _buffered += bytes; - if (_shouldAdvanceFrameHeader) - { - bytes += FrameHeaderSize; - _shouldAdvanceFrameHeader = false; - } - _inner.Advance(bytes); - } - - public override void CancelPendingFlush() - { - _inner.CancelPendingFlush(); - } - - public override void Complete(Exception? exception = null) - { - _inner.Complete(exception); - } - - // TODO: We could reduce this to 16 bytes for binary transports and avoid the base64 encode/decode - // TODO: We could also reduce this to 1 + 12 (or 8) bytes occasionally if we add a flag for no new ack ID and avoid sending an ack - // X - 12 byte - size of payload as long and base64 encoded - // Y - 12 byte - number of acked bytes as long and base64 encoded - // Z - payload - // [ XXXX YYYY ZZZZ ] - public override ValueTask FlushAsync(CancellationToken cancellationToken = default) - { - Debug.Assert(_frameHeader.Length >= FrameHeaderSize); - - WriteFrame(_frameHeader.Span, _buffered, LastAck); - - _frameHeader = Memory.Empty; - _buffered = 0; - return _inner.FlushAsync(cancellationToken); - } - - public override Memory GetMemory(int sizeHint = 0) - { - var segment = _inner.GetMemory(Math.Max(FrameHeaderSize + 1, sizeHint)); - if (_frameHeader.IsEmpty || _buffered == 0) - { - Debug.Assert(segment.Length > FrameHeaderSize); - - _frameHeader = segment.Slice(0, FrameHeaderSize); - segment = segment.Slice(FrameHeaderSize); - _shouldAdvanceFrameHeader = true; - } - return segment; - } - - public override Span GetSpan(int sizeHint = 0) - { - return GetMemory(sizeHint).Span; - } - - public static void WriteFrame(Span header, long length, long ack) - { - Debug.Assert(header.Length >= FrameHeaderSize); - - BinaryPrimitives.WriteInt64LittleEndian(header, length); - var status = Base64.EncodeToUtf8InPlace(header, 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - - BinaryPrimitives.WriteInt64LittleEndian(header.Slice(12), ack); - status = Base64.EncodeToUtf8InPlace(header.Slice(12), 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - } -} diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index 81dc6b449f6b..16aa345debd7 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -5,6 +5,7 @@ using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.SignalR.Internal; @@ -12,20 +13,34 @@ namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class MessageBuffer { private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; + private readonly ConnectionContext _connection; + private readonly IHubProtocol _protocol; + private int _index; private long _totalMessageCount; + private TaskCompletionSource _resend = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // TODO: pass in limits - public MessageBuffer() + public MessageBuffer(ConnectionContext connection, IHubProtocol protocol) { _buffer = new (SerializedHubMessage? Message, long SequenceId)[10]; + for (var i = 0; i < _buffer.Length; i++) + { + _buffer[i].SequenceId = long.MinValue; + } + _connection = connection; + _protocol = protocol; + + _resend.SetResult(new()); } - public async ValueTask WriteAsync(PipeWriter pipeWriter, SerializedHubMessage hubMessage, IHubProtocol protocol, + public async ValueTask WriteAsync(SerializedHubMessage hubMessage, IHubProtocol protocol, CancellationToken cancellationToken) { // No lock because this is always called in a single async loop? // And other methods don't affect the checks here? + // Sending ping does hit this method, but it shouldn't modify any state // TODO: Backpressure @@ -34,6 +49,8 @@ public async ValueTask WriteAsync(PipeWriter pipeWriter, Serialized // ... } + await _resend.Task.ConfigureAwait(false); + try { @@ -44,51 +61,56 @@ public async ValueTask WriteAsync(PipeWriter pipeWriter, Serialized else { // Non-ackable message, don't add to buffer - return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); } _buffer[_index] = (hubMessage, _totalMessageCount); _index = (_index + 1) % _buffer.Length; - return await pipeWriter.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); } catch (Exception ex) { // TODO: specific exception or some identifier needed - // wait for reconnect, send sequencemessage, and then do resend loop - - long latestAckedIndex = -1; - for (var i = 0; i < _buffer.Length - 1; i++) - { - if (_buffer[(_index + i + 1) % _buffer.Length].SequenceId > long.MinValue) - { - latestAckedIndex = (_index + i + 1) % _buffer.Length; - } - } + // wait for reconnect, send SequenceMessage, and then do resend loop - if (latestAckedIndex == -1) + var oldTcs = Interlocked.Exchange(ref _resend, new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously)); + if (!oldTcs.Task.IsCompleted) { - // no unacked messages, probably not possible - // because we are in the middle of writing a message when we get here, so there should be 1 minimum + return await oldTcs.Task.ConfigureAwait(false); } - - protocol.WriteMessage(new SequenceMessage(_buffer[latestAckedIndex].SequenceId), pipeWriter); - await pipeWriter.FlushAsync(cancellationToken).ConfigureAwait(false); - - for (var i = 0; i < _buffer.Length; i++) - { - var item = _buffer[(latestAckedIndex + i) % _buffer.Length]; - if (item.SequenceId > long.MinValue) - { - await pipeWriter.WriteAsync(item.Message!.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); - } - else - { - break; - } - } - - return new FlushResult(isCanceled: false, isCompleted: false); + return await _resend.Task.ConfigureAwait(false); + + //long latestAckedIndex = -1; + //for (var i = 0; i < _buffer.Length - 1; i++) + //{ + // if (_buffer[(_index + i + 1) % _buffer.Length].SequenceId > long.MinValue) + // { + // latestAckedIndex = (_index + i + 1) % _buffer.Length; + // } + //} + + //if (latestAckedIndex == -1) + //{ + // // no unacked messages, probably not possible + // // because we are in the middle of writing a message when we get here, so there should be 1 minimum + //} + + //protocol.WriteMessage(new SequenceMessage(_buffer[latestAckedIndex].SequenceId), connection.Transport.Output); + //await connection.Transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); + + //for (var i = 0; i < _buffer.Length; i++) + //{ + // var item = _buffer[(latestAckedIndex + i) % _buffer.Length]; + // if (item.SequenceId > long.MinValue) + // { + // await connection.Transport.Output.WriteAsync(item.Message!.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + // } + // else + // { + // break; + // } + //} } } @@ -136,4 +158,63 @@ internal void ResetSequence(SequenceMessage sequenceMessage) } _currentReceivingSequenceId = sequenceMessage.SequenceId; } + + internal void Resend() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var oldTcs = Interlocked.Exchange(ref _resend, tcs); + if (!oldTcs.Task.IsCompleted) + { + Interlocked.Exchange(ref _resend, oldTcs); + tcs = oldTcs; + } + _ = DoResendAsync(tcs); + } + + private async Task DoResendAsync(TaskCompletionSource tcs) + { + long latestAckedIndex = -1; + for (var i = 0; i < _buffer.Length - 1; i++) + { + if (_buffer[(_index + i + 1) % _buffer.Length].SequenceId > long.MinValue) + { + latestAckedIndex = (_index + i + 1) % _buffer.Length; + break; + } + } + + if (latestAckedIndex == -1) + { + // no unacked messages, probably not possible + // because we are in the middle of writing a message when we get here, so there should be 1 minimum + } + + FlushResult finalResult = new(); + try + { + _protocol.WriteMessage(new SequenceMessage(_buffer[latestAckedIndex].SequenceId), _connection.Transport.Output); + finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); + + for (var i = 0; i < _buffer.Length; i++) + { + var item = _buffer[(latestAckedIndex + i) % _buffer.Length]; + if (item.SequenceId > long.MinValue) + { + finalResult = await _connection.Transport.Output.WriteAsync(item.Message!.GetSerializedMessage(_protocol)).ConfigureAwait(false); + } + else + { + break; + } + } + } + catch (Exception ex) + { + tcs.SetException(ex); + } + finally + { + tcs.TrySetResult(finalResult); + } + } } diff --git a/src/SignalR/common/Shared/ParseAckPipeReader.cs b/src/SignalR/common/Shared/ParseAckPipeReader.cs deleted file mode 100644 index 390bee81f5b1..000000000000 --- a/src/SignalR/common/Shared/ParseAckPipeReader.cs +++ /dev/null @@ -1,176 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System; -using System.Buffers; -using System.Buffers.Binary; -using System.Buffers.Text; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Threading; -using System.Threading.Tasks; - -#nullable enable - -namespace Microsoft.AspNetCore.Http.Connections; - -// Read from "network" -// Parse framing and slice the read so the application doesn't see the framing -// Notify outbound pipe of framing details for when sending back -// Notify application pipe of ack id provided by other side of the network -internal sealed class ParseAckPipeReader : PipeReader -{ - private const int FrameHeaderSize = 24; - private readonly PipeReader _inner; - private readonly AckPipeWriter _ackPipeWriter; - private readonly AckPipeReader _ackPipeReader; - private long _totalBytes; - private long _remaining; - - private ReadOnlySequence _currentRead; - - public ParseAckPipeReader(PipeReader inner, AckPipeWriter ackPipeWriter, AckPipeReader ackPipeReader) - { - _inner = inner; - _ackPipeWriter = ackPipeWriter; - _ackPipeReader = ackPipeReader; - } - - public override void AdvanceTo(SequencePosition consumed) - { - CommonAdvance(ref consumed); - _inner.AdvanceTo(consumed); - } - - public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) - { - CommonAdvance(ref consumed); - _inner.AdvanceTo(consumed, examined); - } - - private void CommonAdvance(ref SequencePosition consumed) - { - // Get the number of bytes consumed to update our internal state - var len = _currentRead.Length; - // This is used by ReadAsync to help update the ack id - _currentRead = _currentRead.Slice(consumed); - len -= _currentRead.Length; - - _remaining -= len; - } - - public override void CancelPendingRead() - { - _inner.CancelPendingRead(); - } - - public override void Complete(Exception? exception = null) - { - _inner.Complete(exception); - } - - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) - { - var res = await _inner.ReadAsync(cancellationToken).ConfigureAwait(false); - try - { - var newBytes = res.Buffer.Length - _currentRead.Length; - _currentRead = res.Buffer; - - if (res.IsCompleted || res.IsCanceled) - { - // TODO: figure out behavior - if (res.Buffer.Length >= FrameHeaderSize) - { - res = new(res.Buffer.Slice(FrameHeaderSize), res.IsCanceled, res.IsCompleted); - } - return res; - } - - ReadOnlySequence buffer = res.Buffer; - if (_remaining == 0) - { - // TODO: didn't get 24 bytes - var frame = buffer.Slice(0, FrameHeaderSize); - var len = ParseFrame(frame, _ackPipeReader); - _totalBytes += len; - - _remaining = len; - - // if the buffer doesn't have enough data we need to update how much we're slicing - if (len > buffer.Length - FrameHeaderSize) - { - len = buffer.Length - FrameHeaderSize; - } - - buffer = buffer.Slice(FrameHeaderSize, len); - _currentRead = buffer; - // 0 length means it was part of the reconnect handshake and not sent over the pipe, ignore it for acking purposes - // TODO: check if 0 byte writes are possible in ConnectionHandlers and possibly handle them differently - _ackPipeWriter.LastAck += buffer.Length == 0 ? 0 : buffer.Length + FrameHeaderSize; - } - else - { - // Advance was called and didn't consume everything even though we gave it the entire Frame Length of data - // This means the caller is expecting more than a single frame of data - // We'll need to start buffering to parse multiple frames of data - if (_remaining <= _currentRead.Length && buffer.Length > _remaining) - { - // TODO: multi-frame support - } - _ackPipeWriter.LastAck += Math.Min(_remaining, newBytes); - _currentRead = buffer; - buffer = buffer.Slice(0, Math.Min(_remaining, buffer.Length)); - } - - // TODO: validation everywhere! - //Debug.Assert(len < res.Buffer.Length); - - res = new(buffer, res.IsCanceled, res.IsCompleted); - - // TODO: probably should avoid returning when we have 0 bytes to return (unless canceled/completed) - //Debug.Assert(buffer.Length > 0); - } - catch (Exception ex) - { - _inner.Complete(ex); - throw; - } - - return res; - } - - public static long ParseFrame(ReadOnlySequence frame, AckPipeReader ackPipeReader) - { - Debug.Assert(frame.Length >= FrameHeaderSize); - frame = frame.Slice(0, FrameHeaderSize); - - long len; - long ackId; - - // TODO: check perf of single Span check vs Stackalloc - Span buffer = stackalloc byte[FrameHeaderSize]; - frame.CopyTo(buffer); - var status = Base64.DecodeFromUtf8InPlace(buffer.Slice(0, FrameHeaderSize / 2), out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 8); - - len = BinaryPrimitives.ReadInt64LittleEndian(buffer); - - var ackFrame = buffer.Slice(FrameHeaderSize / 2); - status = Base64.DecodeFromUtf8InPlace(ackFrame, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 8); - ackId = BinaryPrimitives.ReadInt64LittleEndian(ackFrame); - - // Update ack id provided by other side, so the underlying pipe can release buffered memory - ackPipeReader.Ack(ackId); - return len; - } - - public override bool TryRead(out ReadResult result) - { - // TODO: Not needed for SignalR, but could be called in ConnectionHandler layer of user code - throw new NotImplementedException(); - } -} diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs deleted file mode 100644 index 0b99495c9b47..000000000000 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/AckPipeTests.cs +++ /dev/null @@ -1,830 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Buffers; -using System.Buffers.Text; -using System.Diagnostics; -using System.IO.Pipelines; -using Microsoft.AspNetCore.Http.Connections; -using Microsoft.AspNetCore.Testing; - -namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol; - -public class AckPipeTests -{ - private const int FrameSize = 24; - - [Fact] - public async Task CanSendAndReceiveTransport() - { - var duplexPipe = CreateConnectionPair(new PipeOptions(), new PipeOptions()); - - var values = new byte[] { 1, 2, 3, 4, 5 }; - var flushRes = await duplexPipe.Transport.Output.WriteAsync(values); - - Assert.False(flushRes.IsCanceled); - Assert.False(flushRes.IsCompleted); - - var readResult = await duplexPipe.Application.Input.ReadAsync(); - - Assert.False(readResult.IsCanceled); - Assert.False(readResult.IsCompleted); - Assert.Equal(values.Length, readResult.Buffer.Length); - Assert.Equal(values, readResult.Buffer.ToArray()); - } - - [Fact] - public async Task CanSendAndReceiveLargeAmount() - { - var duplexPipe = CreateConnectionPair(new PipeOptions(), new PipeOptions()); - - var values = new byte[20000]; - Random.Shared.NextBytes(values); - var flushRes = await duplexPipe.Transport.Output.WriteAsync(values); - - Assert.False(flushRes.IsCanceled); - Assert.False(flushRes.IsCompleted); - - var readResult = await duplexPipe.Application.Input.ReadAsync(); - - Assert.False(readResult.IsCanceled); - Assert.False(readResult.IsCompleted); - Assert.Equal(values.Length, readResult.Buffer.Length); - Assert.Equal(values, readResult.Buffer.ToArray()); - } - - [Fact] - public async Task CanSendAndReceiveLargeAmount_ManyWritesSingleFlush() - { - var duplexPipe = CreateConnectionPair(new PipeOptions(), new PipeOptions()); - - var values = new byte[20000]; - Random.Shared.NextBytes(values); - var written = 0; - while (written < values.Length) - { - var mem = duplexPipe.Transport.Output.GetMemory(); - var toWrite = Math.Min(mem.Length, values.Length - written); - values.AsSpan(written, toWrite).CopyTo(mem.Span); - duplexPipe.Transport.Output.Advance(toWrite); - written += toWrite; - } - - var flushRes = await duplexPipe.Transport.Output.FlushAsync(); - - Assert.False(flushRes.IsCanceled); - Assert.False(flushRes.IsCompleted); - - var readResult = await duplexPipe.Application.Input.ReadAsync(); - - Assert.False(readResult.IsCanceled); - Assert.False(readResult.IsCompleted); - Assert.Equal(values.Length, readResult.Buffer.Length); - Assert.Equal(values, readResult.Buffer.ToArray()); - } - - [Fact] - public async Task ReadFromTransportRemovesFraming() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[28]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize, 0); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - var res = await duplexPipe.Transport.Input.ReadAsync(); - Assert.Equal(4, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize).ToArray(), res.Buffer.ToArray()); - } - - [Fact] - public async Task WriteFromApplicationAddsFraming() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[20]; - Random.Shared.NextBytes(buffer); - - await duplexPipe.Transport.Output.WriteAsync(buffer); - - var res = await duplexPipe.Application.Input.ReadAsync(); - var framing = ReadFrame(res.Buffer.ToArray()); - - Assert.Equal(buffer.Length, framing.Length); - Assert.Equal(0, framing.AckId); - Assert.Equal(buffer.Length + FrameSize, res.Buffer.Length); - Assert.Equal(buffer, res.Buffer.Slice(FrameSize).ToArray()); - } - - [Fact] - public async Task MultipleWritesSingleFlushFromApplicationAddsFraming() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[20]; - Random.Shared.NextBytes(buffer); - - for (var i = 0; i < 3; i++) - { - var memory = duplexPipe.Transport.Output.GetMemory(); - buffer.CopyTo(memory); - duplexPipe.Transport.Output.Advance(buffer.Length); - } - await duplexPipe.Transport.Output.FlushAsync(); - - var res = await duplexPipe.Application.Input.ReadAsync(); - var framing = ReadFrame(res.Buffer.ToArray()); - - Assert.Equal(buffer.Length * 3, framing.Length); - Assert.Equal(0, framing.AckId); - Assert.Equal(buffer.Length * 3 + FrameSize, res.Buffer.Length); - Assert.Equal(buffer, res.Buffer.Slice(FrameSize, buffer.Length).ToArray()); - Assert.Equal(buffer, res.Buffer.Slice(FrameSize + buffer.Length, buffer.Length).ToArray()); - Assert.Equal(buffer, res.Buffer.Slice(FrameSize + buffer.Length * 2, buffer.Length).ToArray()); - } - - [Fact] - public async Task ReadFromTransportAcrossMultipleReads() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[28]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize + buffer.Length + buffer.Length, 0); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - var res = await duplexPipe.Transport.Input.ReadAsync(); - - Assert.Equal(4, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize).ToArray(), res.Buffer.ToArray()); - - // consume nothing - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); - - await duplexPipe.Application.Output.WriteAsync(buffer); - res = await duplexPipe.Transport.Input.ReadAsync(); - - Assert.Equal(32, res.Buffer.Length); - - // consume nothing - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); - - await duplexPipe.Application.Output.WriteAsync(buffer); - res = await duplexPipe.Transport.Input.ReadAsync(); - - Assert.Equal(60, res.Buffer.Length); - - // consume everything - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - // New write to make sure internal state is cleared from completed read - WriteFrame(buffer, buffer.Length - FrameSize, 0); - await duplexPipe.Application.Output.WriteAsync(buffer); - res = await duplexPipe.Transport.Input.ReadAsync(); - - Assert.Equal(4, res.Buffer.Length); - } - - [Fact] - public async Task ManyWritesSingleFlush_WritesSingleFrame() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[20]; - Random.Shared.NextBytes(buffer); - - var memory = duplexPipe.Transport.Output.GetMemory(); - Assert.True(memory.Length > buffer.Length); - buffer.CopyTo(memory); - duplexPipe.Transport.Output.Advance(buffer.Length); - - memory = duplexPipe.Transport.Output.GetMemory(); - Assert.True(memory.Length > buffer.Length); - buffer.CopyTo(memory); - duplexPipe.Transport.Output.Advance(buffer.Length); - - memory = duplexPipe.Transport.Output.GetMemory(); - Assert.True(memory.Length > buffer.Length); - buffer.CopyTo(memory); - duplexPipe.Transport.Output.Advance(buffer.Length); - - await duplexPipe.Transport.Output.FlushAsync(); - - var res = await duplexPipe.Application.Input.ReadAsync(); - var framing = ReadFrame(res.Buffer.ToArray()); - - Assert.Equal(buffer.Length * 3, framing.Length); - Assert.Equal(0, framing.AckId); - Assert.Equal(framing.Length + FrameSize, res.Buffer.Length); - - var buf = res.Buffer.Slice(FrameSize); - while (buf.Length > 0) - { - Assert.Equal(buffer, buf.Slice(0, buffer.Length).ToArray()); - buf = buf.Slice(buffer.Length); - } - } - - [Fact(Skip = "Something we want to support?")] - public async Task ReadFromTransportAcrossFrames() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[20]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize, 0); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - var res = await duplexPipe.Transport.Input.ReadAsync(); - - Assert.Equal(4, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize).ToArray(), res.Buffer.ToArray()); - - // consume nothing - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); - - res = await duplexPipe.Transport.Input.ReadAsync(); - // ?? - } - - [Fact] - public async Task AckFromTransportReadUpdatesApplicationBuffer() - { - var duplexPipe = CreateClient(); - // write something so we can ack it and see that the pipe has nothing in it - await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - var res = await duplexPipe.Application.Input.ReadAsync(); - // in real usage this will be advanced properly - // but we're claiming we read nothing so we can observe the ack behavior in the next read - duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start); - - var buffer = new byte[28]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - 24, ackId: FrameSize + 2); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - // this reads the ack from the "server" and updates state - _ = await duplexPipe.Transport.Input.ReadAsync(); - - // this will be an empty read because the ack will be applied and everything will be marked as read - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(0, res.Buffer.Length); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - } - - [Fact] - public async Task AckFromTransportReadUpdatesApplicationBuffer_CanReadNewDataAfter() - { - // Basically the same test as AckFromTransportReadUpdatesApplicationBuffer but we write more data after the ack has fully flowed - // Just to smoke test that the pipe is still usable - - var duplexPipe = CreateClient(); - // write something so we can ack it - await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - var res = await duplexPipe.Application.Input.ReadAsync(); - // in real usage this will be advanced properly - // but we're claiming we read nothing so we can observe the ack behavior in the next read - duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start); - - var buffer = new byte[FrameSize + 4]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize, ackId: FrameSize + 2); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - // this reads the ack from the "server" and updates state - res = await duplexPipe.Transport.Input.ReadAsync(); - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - // write again to update total sent - await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 42, 99 }); - - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(FrameSize + 2, res.Buffer.Length); - var (len, ack) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(2, len); - Assert.Equal(FrameSize + 4, ack); - Assert.Equal(new byte[] { 42, 99 }, res.Buffer.Slice(FrameSize).ToArray()); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - } - - [Fact] - public async Task ReceiveAckIdLargerThanTotalSentErrors() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[28]; - Random.Shared.NextBytes(buffer); - // ackId more than what has been sent - WriteFrame(buffer, buffer.Length - FrameSize, ackId: 30); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - var exception = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); - Assert.Equal("Ack ID '30' is greater than total amount of '0' bytes that have been sent.", exception.Message); - - // Pipe is completed - exception = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); - Assert.Equal("Reading is not allowed after reader was completed.", exception.Message); - } - - // This is a fun edge case test, where if we have consumed everything in a BufferSegment and Acked everything too - // then consumed points to the end of the Segment, while Ack points to the beginning of the next Segment - // This test verifies that everything behaves correctly in that case - [Fact] - public async Task ConsumeAndAckAtEndOfSegment_CanServeNextSegment() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[4072]; - Random.Shared.NextBytes(buffer); - - // "write" from server - await duplexPipe.Transport.Output.WriteAsync(buffer); - - // read in client application layer - var res = await duplexPipe.Application.Input.ReadAsync(); - duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); - - Random.Shared.NextBytes(buffer); - await duplexPipe.Transport.Output.WriteAsync(buffer); - - var appBuffer = new byte[28]; - Random.Shared.NextBytes(appBuffer); - WriteFrame(appBuffer, appBuffer.Length - FrameSize, 4096); - await duplexPipe.Application.Output.WriteAsync(appBuffer); - - // Updates Ack in Application.Input - await duplexPipe.Transport.Input.ReadAsync(); - - res = await duplexPipe.Application.Input.ReadAsync(); - Assert.Equal(4096, res.Buffer.Length); - var (len, ack) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(4072, len); - Assert.Equal(0, ack); - Assert.Equal(buffer, res.Buffer.Slice(FrameSize).ToArray()); - Assert.True(res.Buffer.IsSingleSegment); - } - - [Fact] - public async Task ApplicationSendsAck() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[FrameSize + 4]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize, 0); - - // "write" from server - await duplexPipe.Application.Output.WriteAsync(buffer); - - // read in client application layer - var res = await duplexPipe.Transport.Input.ReadAsync(); - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - _ = await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - res = await duplexPipe.Application.Input.ReadAsync(); - var (length, ackId) = ReadFrame(res.Buffer.ToArray()); - - Assert.Equal(2, length); - Assert.Equal(FrameSize + 4, ackId); - } - - [Fact] - public async Task ApplicationSendsAckWithMultiSegment_ConsumingWhileReading() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[FrameSize + 5]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize, 0); - - // "write" from server, 26 of the 29 bytes, we want to force the reader to do two reads to get the full data - await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(0, FrameSize + 2).ToArray()); - - // read in client application layer - var res = await duplexPipe.Transport.Input.ReadAsync(); - Assert.Equal(2, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize, 2).ToArray(), res.Buffer.ToArray()); - // Consume all seen so far - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - // write again, the last 3 of the 29 bytes - await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(FrameSize + 2, 3).ToArray()); - - res = await duplexPipe.Transport.Input.ReadAsync(); - Assert.Equal(3, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize + 2, 3).ToArray(), res.Buffer.ToArray()); - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - _ = await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - res = await duplexPipe.Application.Input.ReadAsync(); - var (length, ackId) = ReadFrame(res.Buffer.ToArray()); - - Assert.Equal(2, length); - Assert.Equal(FrameSize + 5, ackId); - } - - [Fact] - public async Task ApplicationSendsAckWithMultiSegment_OnlyConsumeAtEnd() - { - var duplexPipe = CreateClient(); - - var buffer = new byte[29]; - Random.Shared.NextBytes(buffer); - WriteFrame(buffer, buffer.Length - FrameSize, 0); - - // "write" from server, 26 of the 29 bytes, we want to force the reader to do two reads to get the full data - await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(0, 26).ToArray()); - - // read in client application layer - var res = await duplexPipe.Transport.Input.ReadAsync(); - Assert.Equal(2, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize, 2).ToArray(), res.Buffer.ToArray()); - // Don't consume any - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); - - // write again, the last 3 of the 29 bytes - await duplexPipe.Application.Output.WriteAsync(buffer.AsSpan(26, 3).ToArray()); - - res = await duplexPipe.Transport.Input.ReadAsync(); - Assert.Equal(5, res.Buffer.Length); - Assert.Equal(buffer.AsSpan(FrameSize, 5).ToArray(), res.Buffer.ToArray()); - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - _ = await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - res = await duplexPipe.Application.Input.ReadAsync(); - var (length, ackId) = ReadFrame(res.Buffer.ToArray()); - - Assert.Equal(2, length); - Assert.Equal(29, ackId); - } - - [Fact] - public async Task CompleteWithErrorFromTransportWriterFlowsToAppReader() - { - var duplexPipe = CreateClient(); - - duplexPipe.Transport.Output.Complete(new Exception("custom")); - - var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Application.Input.ReadAsync()); - Assert.Equal("custom", ex.Message); - } - - [Fact] - public async Task CompleteWithErrorFromTransportReaderFlowsToAppWriter() - { - var duplexPipe = CreateClient(); - - duplexPipe.Transport.Input.Complete(new Exception("custom")); - - var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Application.Output.FlushAsync()); - Assert.Equal("custom", ex.Message); - } - - [Fact] - public async Task CompleteWithErrorFromAppWriterFlowsToTransportReader() - { - var duplexPipe = CreateClient(); - - duplexPipe.Application.Output.Complete(new Exception("custom")); - - var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Input.ReadAsync()); - Assert.Equal("custom", ex.Message); - } - - [Fact] - public async Task CompleteWithErrorFromAppReaderFlowsToTransportWriter() - { - var duplexPipe = CreateClient(); - - duplexPipe.Application.Input.Complete(new Exception("custom")); - - var ex = await Assert.ThrowsAsync(async () => await duplexPipe.Transport.Output.WriteAsync(new byte[1])); - Assert.Equal("custom", ex.Message); - } - - [Fact] - public async Task TriggerResendWithNothingWritten() - { - var duplexPipe = CreateClient(); - - var reader = (AckPipeReader)duplexPipe.Application.Input; - reader.Resend(); - - await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - var res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(FrameSize + 2, res.Buffer.Length); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - - var (length, ackId) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(2, length); - Assert.Equal(0, ackId); - } - - [Fact] - public async Task TriggerResendWithEverythingAcked() - { - var duplexPipe = CreateClient(); - - await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - // Read to pretend we've sent 18 bytes, so that an ack will be allowed - var res = await duplexPipe.Application.Input.ReadAsync(); - duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start, res.Buffer.Start); - - var buffer = new byte[FrameSize]; - WriteFrame(buffer, 0, FrameSize + 2); - await duplexPipe.Application.Output.WriteAsync(buffer); - - // Updates ack from App.Output in App.Input - _ = await duplexPipe.Transport.Input.ReadAsync(); - - var reader = (AckPipeReader)duplexPipe.Application.Input; - reader.Resend(); - - // Nothing returned since everything was acked before resend triggered - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(0, res.Buffer.Length); - Assert.True(res.IsCanceled); - Assert.False(res.IsCompleted); - duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); - - // smoke testing that we can still receive - await duplexPipe.Transport.Output.WriteAsync(new byte[2]); - - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(FrameSize + 2, res.Buffer.Length); - var (len, ackId) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(2, len); - Assert.Equal(0, ackId); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - } - - [Fact] - public async Task TriggerResendSendsEverythingNotAcked() - { - var duplexPipe = CreateClient(); - - // Write two frames of data - await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 1, 2 }); - await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 3, 4 }); - - var reader = (AckPipeReader)duplexPipe.Application.Input; - reader.Resend(); - - var res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(52, res.Buffer.Length); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - var (len, ackId) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(2, len); - Assert.Equal(0, ackId); - Assert.Equal(new byte[] { 1, 2 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); - (len, ackId) = ReadFrame(res.Buffer.ToArray().AsSpan(FrameSize + 2).ToArray()); - Assert.Equal(2, len); - Assert.Equal(0, ackId); - Assert.Equal(new byte[] { 3, 4 }, res.Buffer.ToArray().AsSpan(FrameSize * 2 + 2, 2).ToArray()); - - duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); - - // smoke testing that we can still receive - await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 4, 5 }); - - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(FrameSize + 2, res.Buffer.Length); - (len, ackId) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(2, len); - Assert.Equal(0, ackId); - Assert.Equal(new byte[] { 4, 5 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - } - - [Fact] - public async Task TriggerResendWhenPartialFrameAcked() - { - var duplexPipe = CreateClient(); - - await duplexPipe.Transport.Output.WriteAsync(new byte[] { 1, 2, 3, 4, 5, 6, 7 }); - // Read to pretend we've sent 31 bytes, so that an ack will be allowed - var res = await duplexPipe.Application.Input.ReadAsync(); - duplexPipe.Application.Input.AdvanceTo(res.Buffer.Start, res.Buffer.Start); - - var buffer = new byte[FrameSize]; - // Only ack 26 of 31 bytes - WriteFrame(buffer, 0, FrameSize + 2); - await duplexPipe.Application.Output.WriteAsync(buffer); - - // Updates ack from App.Output in App.Input - res = await duplexPipe.Transport.Input.ReadAsync(); - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.End); - - var reader = (AckPipeReader)duplexPipe.Application.Input; - reader.Resend(); - - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(5, res.Buffer.Length); - Assert.True(res.IsCanceled); - Assert.False(res.IsCompleted); - Assert.Equal(new byte[] { 3, 4, 5, 6, 7 }, res.Buffer.ToArray()); - - duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); - - // smoke testing that we can still receive - await duplexPipe.Transport.Output.WriteAsync(new byte[2] { 9, 7 }); - - res = await duplexPipe.Application.Input.ReadAsync(); - - Assert.Equal(FrameSize + 2, res.Buffer.Length); - var (len, ackId) = ReadFrame(res.Buffer.ToArray()); - Assert.Equal(2, len); - Assert.Equal(0, ackId); - Assert.Equal(new byte[] { 9, 7 }, res.Buffer.ToArray().AsSpan(FrameSize, 2).ToArray()); - Assert.False(res.IsCanceled); - Assert.False(res.IsCompleted); - } - - [Fact] - public async Task BackpressureIsAppliedInBothDirections() - { - var duplexPipe = CreateClient(inputOptions: new PipeOptions(pauseWriterThreshold: 10, resumeWriterThreshold: 5, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline), - outputOptions: new PipeOptions(pauseWriterThreshold: 10, resumeWriterThreshold: 5, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline)); - - var buffer = new byte[FrameSize + 1]; - WriteFrame(buffer, 1, 0); - var writeTask = duplexPipe.Application.Output.WriteAsync(buffer); - // Shouldn't complete until the reader reads due to pauseWriterThreshold being 10 and we wrote 25 - Assert.False(writeTask.IsCompleted); - - var res = await duplexPipe.Transport.Input.ReadAsync(); - Assert.Equal(1, res.Buffer.Length); - duplexPipe.Transport.Input.AdvanceTo(res.Buffer.Start, res.Buffer.End); - await writeTask.DefaultTimeout(); - - writeTask = duplexPipe.Transport.Output.WriteAsync(new byte[2] { 4, 5 }); - // Shouldn't complete until the reader reads due to pauseWriterThreshold being 10 and we wrote 26 - Assert.False(writeTask.IsCompleted); - - res = await duplexPipe.Application.Input.ReadAsync(); - Assert.Equal(26, res.Buffer.Length); - duplexPipe.Application.Input.AdvanceTo(res.Buffer.End); - await writeTask.DefaultTimeout(); - } - - internal static DuplexPipePair CreateClient(PipeOptions inputOptions = default, PipeOptions outputOptions = default) - { - var input = new Pipe(inputOptions ?? new()); - var output = new Pipe(outputOptions ?? new()); - - // Use for one side only, this is client side - var ackWriter = new AckPipeWriter(output); - var ackReader = new AckPipeReader(output); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); - var transportToApplication = new DuplexPipe(ackReader, input.Writer); - var applicationToTransport = new DuplexPipe(transportReader, ackWriter); - - // Transport.Output.Write goes to Application.Input, which is read in the transport code - // Application.Output.Write goes to Transport.Input, which is read in the application code - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } - - internal static DuplexPipePair CreateServer(PipeOptions inputOptions = default, PipeOptions outputOptions = default) - { - var input = new Pipe(inputOptions ?? new()); - var output = new Pipe(outputOptions ?? new()); - - // Use for one side only, this is server side - var ackWriter = new AckPipeWriter(output); - var ackReader = new AckPipeReader(output); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriter, ackReader); - var transportToApplication = new DuplexPipe(ackReader, input.Writer); - var applicationToTransport = new DuplexPipe(transportReader, ackWriter); - - return new DuplexPipePair(transportToApplication, applicationToTransport); - } - - internal static void WriteFrame(byte[] header, long payloadLength, long ackId = 0) - { - Assert.True(header.Length >= FrameSize); - - Assert.True(BitConverter.TryWriteBytes(header, payloadLength)); - Assert.True(BitConverter.TryWriteBytes(header.AsSpan(8), ackId)); - var res = BitConverter.TryWriteBytes(header.AsSpan(), payloadLength); - Debug.Assert(res); - var status = Base64.EncodeToUtf8InPlace(header.AsSpan(), 8, out var written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - res = BitConverter.TryWriteBytes(header.AsSpan(12), ackId); - Debug.Assert(res); - status = Base64.EncodeToUtf8InPlace(header.AsSpan(12), 8, out written); - Debug.Assert(status == OperationStatus.Done); - Debug.Assert(written == 12); - } - - internal static (long Length, long AckId) ReadFrame(byte[] frameBytes) - { - var frame = frameBytes.AsSpan(0, FrameSize); - Span buffer = stackalloc byte[FrameSize]; - frame.CopyTo(buffer); - var status = Base64.DecodeFromUtf8InPlace(buffer.Slice(0, 12), out var written); - Assert.Equal(OperationStatus.Done, status); - Assert.Equal(8, written); - var len = BitConverter.ToInt64(buffer); - status = Base64.DecodeFromUtf8InPlace(buffer.Slice(12, 12), out written); - Assert.Equal(OperationStatus.Done, status); - Assert.Equal(8, written); - var ackId = BitConverter.ToInt64(buffer.Slice(12)); - - return (len, ackId); - } - - internal static (long PayloadLength, long AckId) ReadFrame(ref Span header) - { - Assert.True(header.Length >= FrameSize); - - var len = BitConverter.ToInt64(header); - var ackId = BitConverter.ToInt64(header.Slice(FrameSize / 2)); - - return (len, ackId); - } - - internal static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - // wire up both sides for testing - var ackWriterApp = new AckPipeWriter(output); - var ackReaderApp = new AckPipeReader(output); - var ackWriterClient = new AckPipeWriter(input); - var ackReaderClient = new AckPipeReader(input); - var transportReader = new ParseAckPipeReader(input.Reader, ackWriterApp, ackReaderApp); - var applicationReader = new ParseAckPipeReader(ackReaderApp, ackWriterClient, ackReaderClient); - var transportToApplication = new DuplexPipe(applicationReader, ackWriterClient); - var applicationToTransport = new DuplexPipe(transportReader, ackWriterApp); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } - - internal sealed class DuplexPipe : IDuplexPipe - { - public DuplexPipe(PipeReader reader, PipeWriter writer) - { - Input = reader; - Output = writer; - } - - public PipeReader Input { get; } - - public PipeWriter Output { get; } - } - - public readonly struct DuplexPipePair - { - public IDuplexPipe Transport { get; } - public IDuplexPipe Application { get; } - - public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) - { - Transport = transport; - Application = application; - } - } -} diff --git a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj index 1f5288af77e8..1120424ee602 100644 --- a/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj +++ b/src/SignalR/common/SignalR.Common/test/Microsoft.AspNetCore.SignalR.Common.Tests.csproj @@ -7,9 +7,6 @@ - - - diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 07944315e7ba..7fff770ba810 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -8,6 +8,7 @@ using System.IO.Pipelines; using System.Security.Claims; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Internal; @@ -35,8 +36,8 @@ public partial class HubConnectionContext private readonly TimeProvider _timeProvider; private readonly CancellationTokenRegistration _closedRegistration; private readonly CancellationTokenRegistration? _closedRequestedRegistration; - private readonly MessageBuffer _messageBuffer = new(); + private MessageBuffer? _messageBuffer; private StreamTracker? _streamTracker; private long _lastSendTick; private ReadOnlyMemory _cachedPingMessage; @@ -262,7 +263,7 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c var isAck = true; if (isAck) { - return _messageBuffer.WriteAsync(_connectionContext.Transport.Output, new SerializedHubMessage(message), Protocol, cancellationToken); + return _messageBuffer.WriteAsync(new SerializedHubMessage(message), Protocol, cancellationToken); } else { @@ -292,7 +293,7 @@ private ValueTask WriteCore(SerializedHubMessage message, Cancellat var isAck = true; if (isAck) { - return _messageBuffer.WriteAsync(_connectionContext.Transport.Output, message, Protocol, cancellationToken); + return _messageBuffer.WriteAsync(message, Protocol, cancellationToken); } else { @@ -572,6 +573,10 @@ await WriteHandshakeResponseAsync(new HandshakeResponseMessage( Log.HandshakeComplete(_logger, Protocol.Name); await WriteHandshakeResponseAsync(HandshakeResponseMessage.Empty); + + _messageBuffer = new MessageBuffer(_connectionContext, Protocol); + var f = _connectionContext.Features.Get(); + f.NotifyOnReconnect = _messageBuffer.Resend; return true; } else if (overLength) From a08cdc5bd049d8e2c3c025f9b84e9ada3a6222dc Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Tue, 16 May 2023 15:27:30 -0700 Subject: [PATCH 18/25] tiny cleanup --- src/SignalR/common/Shared/MessageBuffer.cs | 34 +------------------ .../server/Core/src/HubConnectionContext.cs | 9 ----- 2 files changed, 1 insertion(+), 42 deletions(-) diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index 16aa345debd7..592aaec98a3c 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -80,37 +80,6 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, return await oldTcs.Task.ConfigureAwait(false); } return await _resend.Task.ConfigureAwait(false); - - //long latestAckedIndex = -1; - //for (var i = 0; i < _buffer.Length - 1; i++) - //{ - // if (_buffer[(_index + i + 1) % _buffer.Length].SequenceId > long.MinValue) - // { - // latestAckedIndex = (_index + i + 1) % _buffer.Length; - // } - //} - - //if (latestAckedIndex == -1) - //{ - // // no unacked messages, probably not possible - // // because we are in the middle of writing a message when we get here, so there should be 1 minimum - //} - - //protocol.WriteMessage(new SequenceMessage(_buffer[latestAckedIndex].SequenceId), connection.Transport.Output); - //await connection.Transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); - - //for (var i = 0; i < _buffer.Length; i++) - //{ - // var item = _buffer[(latestAckedIndex + i) % _buffer.Length]; - // if (item.SequenceId > long.MinValue) - // { - // await connection.Transport.Output.WriteAsync(item.Message!.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); - // } - // else - // { - // break; - // } - //} } } @@ -185,8 +154,7 @@ private async Task DoResendAsync(TaskCompletionSource tcs) if (latestAckedIndex == -1) { - // no unacked messages, probably not possible - // because we are in the middle of writing a message when we get here, so there should be 1 minimum + // no unacked messages, still send SequenceMessage? } FlushResult finalResult = new(); diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 7fff770ba810..ad2ba3fe18f5 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -765,15 +765,6 @@ internal void Ack(AckMessage ackMessage) _messageBuffer.Ack(ackMessage); } - //private long? GetSequenceId() - //{ - // if (UseAcks) - // { - // return Interlocked.Increment(ref _sequenceId); - // } - // return null; - //} - private long _currentReceivingSequenceId; internal bool ShouldProcessMessage(HubInvocationMessage message) From be2c5d65bd6fb03c6ae9da56587e243bf4aa1d96 Mon Sep 17 00:00:00 2001 From: Brennan Date: Thu, 18 May 2023 09:17:30 -0700 Subject: [PATCH 19/25] stash fix bugs --- .../csharp/Client.Core/src/HubConnection.cs | 1 + .../FunctionalTests/HubConnectionTests.cs | 6 ++--- .../src/Internal/WebSocketsTransport.cs | 23 +++++++++++++++---- .../src/Internal/HttpConnectionContext.cs | 2 +- src/SignalR/common/Shared/MessageBuffer.cs | 5 ++-- .../Core/src/Internal/DefaultHubDispatcher.cs | 1 + 6 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 81672ccbd5cd..bdba65bb5900 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -1023,6 +1023,7 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess { if (!_buffer.ShouldProcessMessage(hubInvocation)) { + _logger.LogInformation($"Dropped {hubInvocation.GetType().Name}. ID: {hubInvocation.InvocationId}"); return null; } } diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index ce464f5b179c..fbb725d63776 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -2543,7 +2543,7 @@ public async Task ServerSentEventsWorksWithHttp2OnlyEndpoint() public async Task CanReconnectAndSendMessageWhileDisconnected() { var protocol = HubProtocols["json"]; - await using (var server = await StartServer()) + await using (var server = await StartServer(w => w.EventId.Name == "ReceivedUnexpectedResponse")) { var websocket = new ClientWebSocket(); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -2599,11 +2599,11 @@ public async Task CanReconnectAndSendMessageWhileDisconnected() } [Fact] - [Repeat(500)] + [Repeat(1500)] public async Task CanReconnectAndSendMessageOnceConnected() { var protocol = HubProtocols["json"]; - await using (var server = await StartServer()) + await using (var server = await StartServer(w => w.EventId.Name == "ReceivedUnexpectedResponse")) { var websocket = new ClientWebSocket(); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index edbe3c9d2c5b..a043a7506270 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -305,6 +305,10 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio _transport = pair.Transport; _application = pair.Application; } + else + { + ignoreFirstCanceled = true; + } // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 @@ -325,7 +329,7 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirs var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false); _stopCts.CancelAfter(_closeTimeout); - + _logger.LogInformation("starting close"); if (trigger == receiving) { // We're waiting for the application to finish and there are 2 things it could be doing @@ -433,6 +437,7 @@ private async Task StartReceiving(WebSocket socket) // or if the consumer is done if (flushResult.IsCanceled || flushResult.IsCompleted) { + _logger.LogInformation("receive: pipe canceled or completed"); break; } } @@ -451,9 +456,10 @@ private async Task StartReceiving(WebSocket socket) } else { - _application.Output.CancelPendingFlush(); + //_application.Output.CancelPendingFlush(); } //_closed = true; + _logger.LogInformation(ex, "receive error"); } } finally @@ -487,6 +493,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) { if (result.IsCanceled && !ignoreFirstCanceled) { + _logger.LogInformation("send canceled"); break; } @@ -519,6 +526,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) } else if (result.IsCompleted) { + _logger.LogInformation("send: pipe result completed"); break; } } @@ -560,7 +568,14 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) { _application.Input.Complete(error); } - // TODO: log error in else? + else + { + if (error is not null) + { + // TODO: log error in else? + _logger.LogInformation(error, "send error"); + } + } Log.SendStopped(_logger); } @@ -638,7 +653,7 @@ private void UpdateConnectionPair() _application = applicationToTransport; _transport = transportToApplication; - prevPipe.Complete(new Exception()); + prevPipe.Complete(new ConnectionResetException("")); _notifyOnReconnect.Invoke(); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 45670068696d..4626aa0024b9 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -666,7 +666,7 @@ private void UpdateConnectionPair() Application = applicationToTransport; Transport = transportToApplication; - prevPipe.Complete(new Exception()); + prevPipe.Complete(new ConnectionResetException("")); Features.GetRequiredFeature().NotifyOnReconnect?.Invoke(); } diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index 592aaec98a3c..a91971b87919 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -68,7 +68,7 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, _index = (_index + 1) % _buffer.Length; return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); } - catch (Exception ex) + catch (ConnectionResetException ex) { // TODO: specific exception or some identifier needed @@ -98,7 +98,8 @@ public void Ack(AckMessage ackMessage) // Release backpressure? } - private long _currentReceivingSequenceId; + // Message IDs start at 1 and always increment by 1 + private long _currentReceivingSequenceId = 1; private long _latestReceivedSequenceId = long.MinValue; internal bool ShouldProcessMessage(HubInvocationMessage message) diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 1ff4b67d24a7..9f38cc0ab7dc 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -139,6 +139,7 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe { if (!connection.ShouldProcessMessage(invocation)) { + _logger.LogInformation($"dropping {invocation.GetType().Name}. ID: {invocation.InvocationId}"); return Task.CompletedTask; } } From 47248e04162b94a57a995b6826acec33f15580ba Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Thu, 18 May 2023 17:32:03 -0700 Subject: [PATCH 20/25] acking --- .../csharp/Client.Core/src/HubConnection.cs | 51 +++++-- .../FunctionalTests/HubConnectionTests.cs | 2 +- src/SignalR/common/Shared/MessageBuffer.cs | 139 ++++++++++++++---- .../SignalR.Common/src/Protocol/AckMessage.cs | 4 +- .../DefaultHubDispatcherBenchmark.cs | 1 + .../server/Core/src/HubConnectionContext.cs | 5 +- .../Core/src/Internal/DefaultHubDispatcher.cs | 2 + 7 files changed, 158 insertions(+), 46 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index bdba65bb5900..bbc961f95fd9 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -8,6 +8,7 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Net; using System.Reflection; @@ -80,7 +81,6 @@ public partial class HubConnection : IAsyncDisposable private readonly ReconnectingConnectionState _state; private bool _disposed; - private MessageBuffer? _buffer; /// /// Occurs when the connection is closed. The connection could be closed due to an error or due to either the server or client intentionally @@ -478,9 +478,6 @@ private async Task StartAsyncCore(CancellationToken cancellationToken) var connection = await _connectionFactory.ConnectAsync(_endPoint, cancellationToken).ConfigureAwait(false); var startingConnectionState = new ConnectionState(connection, this); - // TODO: probably go on ConnectionState - _buffer = new MessageBuffer(connection, _protocol); - // From here on, if an error occurs we need to shut down the connection because // we still own it. try @@ -958,7 +955,7 @@ private async Task SendHubMessage(ConnectionState connectionState, HubMessage hu var isAck = true; if (isAck) { - await _buffer.WriteAsync(new SerializedHubMessage(hubMessage), _protocol, cancellationToken).ConfigureAwait(false); + await connectionState.WriteAsync(new SerializedHubMessage(hubMessage), cancellationToken).ConfigureAwait(false); } else { @@ -1021,7 +1018,7 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess if (true && message is HubInvocationMessage hubInvocation) { - if (!_buffer.ShouldProcessMessage(hubInvocation)) + if (!connectionState.ShouldProcessMessage(hubInvocation)) { _logger.LogInformation($"Dropped {hubInvocation.GetType().Name}. ID: {hubInvocation.InvocationId}"); return null; @@ -1080,10 +1077,12 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess // timeout is reset above, on receiving any message break; case AckMessage ackMessage: - _buffer.Ack(ackMessage); + _logger.LogInformation("Received Ack with ID {id}", ackMessage.SequenceId); + connectionState.Ack(ackMessage); break; case SequenceMessage sequenceMessage: - _buffer.ResetSequence(sequenceMessage); + _logger.LogInformation("Received SequenceMessage with ID {id}", sequenceMessage.SequenceId); + connectionState.ResetSequence(sequenceMessage); break; default: throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}"); @@ -1266,9 +1265,6 @@ private async Task HandshakeAsync(ConnectionState startingConnectionState, Cance Log.HandshakeComplete(_logger); - var f = startingConnectionState.Connection.Features.Get(); - f.NotifyOnReconnect = _buffer.Resend; - break; } } @@ -1859,6 +1855,8 @@ private sealed class ConnectionState : IInvocationBinder private long _nextActivationServerTimeout; private long _nextActivationSendPing; + private MessageBuffer? _buffer; + public ConnectionContext Connection { get; } public Task? ReceiveTask { get; set; } public Exception? CloseException { get; set; } @@ -1884,6 +1882,15 @@ public ConnectionState(ConnectionContext connection, HubConnection hubConnection _logger = _hubConnection._logger; _hasInherentKeepAlive = connection.Features.Get()?.HasInherentKeepAlive ?? false; + + var useAck = true; + if (useAck) + { + _buffer = new MessageBuffer(connection, hubConnection._protocol); + + var f = Connection.Features.Get(); + f.NotifyOnReconnect = _buffer.Resend; + } } public string GetNextId() => (++_nextInvocationId).ToString(CultureInfo.InvariantCulture); @@ -1969,6 +1976,8 @@ private async Task StopAsyncCore() { Log.Stopping(_logger); + _buffer.Dispose(); + // Complete our write pipe, which should cause everything to shut down Log.TerminatingReceiveLoop(_logger); Connection.Transport.Input.CancelPendingRead(); @@ -2000,6 +2009,26 @@ public async Task TimerLoop(TimerAwaitable timer) } } + public ValueTask WriteAsync(SerializedHubMessage message, CancellationToken cancellationToken) + { + return _buffer.WriteAsync(message, cancellationToken); + } + + public bool ShouldProcessMessage(HubInvocationMessage message) + { + return _buffer.ShouldProcessMessage(message); + } + + public void Ack(AckMessage ackMessage) + { + _buffer.Ack(ackMessage); + } + + public void ResetSequence(SequenceMessage sequenceMessage) + { + _buffer.ResetSequence(sequenceMessage); + } + public void ResetSendPing() { Volatile.Write(ref _nextActivationSendPing, (DateTime.UtcNow + _hubConnection.KeepAliveInterval).Ticks); diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index fbb725d63776..087717353cf1 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -2670,7 +2670,7 @@ public async Task CanReconnectAndSendMessageOnceConnected() [Fact] [Repeat(500)] - public async Task ServerAbortsConnectionNoReconnectAttempted() + public async Task ServerAbortsConnectionWithAckingEnabledNoReconnectAttempted() { var protocol = HubProtocols["json"]; await using (var server = await StartServer()) diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index a91971b87919..e2639691336a 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -6,19 +6,33 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.SignalR.Internal; -internal sealed class MessageBuffer +internal sealed class MessageBuffer : IDisposable { private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; private readonly ConnectionContext _connection; private readonly IHubProtocol _protocol; + private readonly AckMessage _ackMessage = new(0); + private readonly SequenceMessage _sequenceMessage = new(0); +#if NET8_0_OR_GREATER + private readonly PeriodicTimer _timer = new(TimeSpan.FromSeconds(1)); +#else + private readonly TimerAwaitable _timer = new(TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); +#endif + private readonly SemaphoreSlim _writeLock = new(1, 1); - private int _index; + private int _bufferIndex; private long _totalMessageCount; + // Message IDs start at 1 and always increment by 1 + private long _currentReceivingSequenceId = 1; + private long _latestReceivedSequenceId = long.MinValue; + private long _lastAckedId = long.MinValue; + private TaskCompletionSource _resend = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); // TODO: pass in limits @@ -33,27 +47,66 @@ public MessageBuffer(ConnectionContext connection, IHubProtocol protocol) _protocol = protocol; _resend.SetResult(new()); + +#if !NET8_0_OR_GREATER + _timer.Start(); +#endif + _ = RunTimer(); } - public async ValueTask WriteAsync(SerializedHubMessage hubMessage, IHubProtocol protocol, - CancellationToken cancellationToken) + private async Task RunTimer() + { + using (_timer) + { +#if NET8_0_OR_GREATER + while (await _timer.WaitForNextTickAsync().ConfigureAwait(false)) +#else + while (await _timer) +#endif + { + if (_lastAckedId < _latestReceivedSequenceId) + { + // TODO: consider a minimum time between sending these? + // If we only read and don't write, this approach isn't great + + var sequenceId = _latestReceivedSequenceId; + _ackMessage.SequenceId = sequenceId; + + await _writeLock.WaitAsync().ConfigureAwait(false); + try + { + _protocol.WriteMessage(_ackMessage, _connection.Transport.Output); + await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); + _lastAckedId = sequenceId; + } + finally + { + _writeLock.Release(); + } + } + } + } + } + + public async ValueTask WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken) { // No lock because this is always called in a single async loop? // And other methods don't affect the checks here? - // Sending ping does hit this method, but it shouldn't modify any state // TODO: Backpressure - if (_buffer[_index].Message is not null) + if (_buffer[_bufferIndex].Message is not null) { // ... } await _resend.Task.ConfigureAwait(false); + var waitForResend = false; + + await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false); try { - if (hubMessage.Message is HubInvocationMessage invocationMessage) { _totalMessageCount++; @@ -61,19 +114,25 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, else { // Non-ackable message, don't add to buffer - return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); } - _buffer[_index] = (hubMessage, _totalMessageCount); - _index = (_index + 1) % _buffer.Length; - return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(protocol), cancellationToken).ConfigureAwait(false); + _buffer[_bufferIndex] = (hubMessage, _totalMessageCount); + _bufferIndex = (_bufferIndex + 1) % _buffer.Length; + return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); } + // TODO: figure out what exception to use catch (ConnectionResetException ex) { - // TODO: specific exception or some identifier needed - - // wait for reconnect, send SequenceMessage, and then do resend loop + waitForResend = true; + } + finally + { + _writeLock.Release(); + } + if (waitForResend) + { var oldTcs = Interlocked.Exchange(ref _resend, new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously)); if (!oldTcs.Task.IsCompleted) { @@ -81,30 +140,33 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, } return await _resend.Task.ConfigureAwait(false); } + + throw new NotImplementedException("shouldn't reach here"); } public void Ack(AckMessage ackMessage) { - var index = _index; + // TODO: what if ackMessage.SequenceId is larger than last sent message? + + // Grabbing _bufferIndex unsynchronized should be fine, we might miss the most recent message but the client shouldn't be able to ack that yet + // Or in exceptional cases we could miss multiple messages, but the next ack will clear them + var index = _bufferIndex; for (var i = 0; i < _buffer.Length; i++) { var currentIndex = (index + i) % _buffer.Length; - if (_buffer[currentIndex].SequenceId <= ackMessage.SequenceId) + if (_buffer[currentIndex].Message is not null && _buffer[currentIndex].SequenceId <= ackMessage.SequenceId) { _buffer[currentIndex] = (null, long.MinValue); } + // TODO: figure out an early exit? } // Release backpressure? } - // Message IDs start at 1 and always increment by 1 - private long _currentReceivingSequenceId = 1; - private long _latestReceivedSequenceId = long.MinValue; - internal bool ShouldProcessMessage(HubInvocationMessage message) { - // TODO: if we're expecting a sequence message but get here we should error + // TODO: if we're expecting a sequence message but get here we should probably error var currentId = _currentReceivingSequenceId; _currentReceivingSequenceId++; @@ -146,23 +208,28 @@ private async Task DoResendAsync(TaskCompletionSource tcs) long latestAckedIndex = -1; for (var i = 0; i < _buffer.Length - 1; i++) { - if (_buffer[(_index + i + 1) % _buffer.Length].SequenceId > long.MinValue) + // TODO: this could grab the index of the just written message from WriteAsync which would result in the wrong value for latestAckedIndex if there are more than 1 messages buffered + if (_buffer[(_bufferIndex + i + 1) % _buffer.Length].SequenceId > long.MinValue) { - latestAckedIndex = (_index + i + 1) % _buffer.Length; + latestAckedIndex = (_bufferIndex + i + 1) % _buffer.Length; break; } } - if (latestAckedIndex == -1) - { - // no unacked messages, still send SequenceMessage? - } - FlushResult finalResult = new(); + await _writeLock.WaitAsync().ConfigureAwait(false); try { - _protocol.WriteMessage(new SequenceMessage(_buffer[latestAckedIndex].SequenceId), _connection.Transport.Output); - finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); + if (latestAckedIndex == -1) + { + // no unacked messages, still send SequenceMessage? + return; + } + + _sequenceMessage.SequenceId = _buffer[latestAckedIndex].SequenceId; + _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); + // don't need to call flush just for the SequenceMessage if we're writing more messages + var shouldFlush = true; for (var i = 0; i < _buffer.Length; i++) { @@ -170,12 +237,18 @@ private async Task DoResendAsync(TaskCompletionSource tcs) if (item.SequenceId > long.MinValue) { finalResult = await _connection.Transport.Output.WriteAsync(item.Message!.GetSerializedMessage(_protocol)).ConfigureAwait(false); + shouldFlush = false; } else { break; } } + + if (shouldFlush) + { + finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); + } } catch (Exception ex) { @@ -183,7 +256,13 @@ private async Task DoResendAsync(TaskCompletionSource tcs) } finally { + _writeLock.Release(); tcs.TrySetResult(finalResult); } } + + public void Dispose() + { + ((IDisposable)_timer).Dispose(); + } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs index 3065673bfe0c..755e3ed5f563 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs @@ -26,7 +26,7 @@ public AckMessage(long sequenceId) /// /// /// - public long SequenceId { get; } + public long SequenceId { get; set; } } public sealed class SequenceMessage : HubMessage @@ -36,5 +36,5 @@ public SequenceMessage(long sequenceId) SequenceId = sequenceId; } - public long SequenceId { get; } + public long SequenceId { get; set; } } diff --git a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs index bdbf8443d752..d0efd3d4489c 100644 --- a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -35,6 +35,7 @@ public void GlobalSetup() new HubContext(hubLifetimeManager), enableDetailedErrors: false, disableImplicitFromServiceParameters: true, + useAcks: false, new Logger>(NullLoggerFactory.Instance), hubFilters: null, hubLifetimeManager); diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index ad2ba3fe18f5..e003f5b9ec95 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -263,7 +263,7 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c var isAck = true; if (isAck) { - return _messageBuffer.WriteAsync(new SerializedHubMessage(message), Protocol, cancellationToken); + return _messageBuffer.WriteAsync(new SerializedHubMessage(message), cancellationToken); } else { @@ -293,7 +293,7 @@ private ValueTask WriteCore(SerializedHubMessage message, Cancellat var isAck = true; if (isAck) { - return _messageBuffer.WriteAsync(message, Protocol, cancellationToken); + return _messageBuffer.WriteAsync(message, cancellationToken); } else { @@ -752,6 +752,7 @@ internal void StopClientTimeout() internal void Cleanup() { + _messageBuffer.Dispose(); _closedRegistration.Dispose(); _closedRequestedRegistration?.Dispose(); diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 9f38cc0ab7dc..a5ce40a49311 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -201,10 +201,12 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe break; case AckMessage ackMessage: + _logger.LogInformation("received ack with id {id}", ackMessage.SequenceId); connection.Ack(ackMessage); break; case SequenceMessage sequenceMessage: + _logger.LogInformation("received sequence message with id {id}", sequenceMessage.SequenceId); connection.ResetSequence(sequenceMessage); break; From 2bf515ad42746ce7d5107e80e2a388285d8c2baa Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Fri, 19 May 2023 15:40:57 -0700 Subject: [PATCH 21/25] more cleanup and bug fixes --- .../src/Features/IReconnectFeature.cs | 9 +++ .../PublicAPI/net462/PublicAPI.Unshipped.txt | 3 + .../PublicAPI/net8.0/PublicAPI.Unshipped.txt | 3 + .../netstandard2.0/PublicAPI.Unshipped.txt | 3 + .../netstandard2.1/PublicAPI.Unshipped.txt | 3 + .../csharp/Client.Core/src/HubConnection.cs | 50 ++++++++------ .../src/HttpConnection.cs | 2 +- .../src/Internal/WebSocketsTransport.cs | 1 + .../src/Internal/HttpConnectionContext.cs | 13 ++-- .../src/Internal/HttpConnectionManager.cs | 1 - .../Transports/WebSocketsServerTransport.cs | 1 - src/SignalR/common/Shared/MessageBuffer.cs | 67 ++++++++++++++++--- .../SignalR.Common/src/Protocol/AckMessage.cs | 12 +++- .../src/Protocol/HubProtocolConstants.cs | 3 + .../PublicAPI/net462/PublicAPI.Unshipped.txt | 18 ++--- .../PublicAPI/net8.0/PublicAPI.Unshipped.txt | 10 +++ .../netstandard2.0/PublicAPI.Unshipped.txt | 10 +++ .../server/Core/src/HubConnectionContext.cs | 30 ++++----- .../server/Core/src/HubConnectionHandler.cs | 3 +- src/SignalR/server/Core/src/HubOptions.cs | 2 - .../Core/src/Internal/DefaultHubDispatcher.cs | 19 +++--- 21 files changed, 183 insertions(+), 80 deletions(-) diff --git a/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs index 94e30d62366a..81a9ebe587a2 100644 --- a/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs +++ b/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs @@ -9,7 +9,16 @@ namespace Microsoft.AspNetCore.Connections.Abstractions; +/// +/// +/// public interface IReconnectFeature { + /// + /// + /// public Action NotifyOnReconnect { get; set; } + + // TODO + // void DisableReconnect(); } diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt index 3e85ed9e89fc..39ed42614d79 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt index 9f1d00bb09ad..9e361db01313 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt index 3e85ed9e89fc..39ed42614d79 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt index 3e85ed9e89fc..39ed42614d79 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index bbc961f95fd9..190d30e1d8e2 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -951,9 +951,7 @@ private async Task SendHubMessage(ConnectionState connectionState, HubMessage hu Log.SendingMessage(_logger, hubMessage); - // TODO - var isAck = true; - if (isAck) + if (connectionState.UsingAcks()) { await connectionState.WriteAsync(new SerializedHubMessage(hubMessage), cancellationToken).ConfigureAwait(false); } @@ -1016,11 +1014,11 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess Log.ResettingKeepAliveTimer(_logger); connectionState.ResetTimeout(); - if (true && message is HubInvocationMessage hubInvocation) + if (connectionState.UsingAcks()) { - if (!connectionState.ShouldProcessMessage(hubInvocation)) + if (!connectionState.ShouldProcessMessage(message)) { - _logger.LogInformation($"Dropped {hubInvocation.GetType().Name}. ID: {hubInvocation.InvocationId}"); + _logger.LogInformation($"Dropped {((HubInvocationMessage)message).GetType().Name}. ID: {((HubInvocationMessage)message).InvocationId}"); return null; } } @@ -1078,11 +1076,17 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess break; case AckMessage ackMessage: _logger.LogInformation("Received Ack with ID {id}", ackMessage.SequenceId); - connectionState.Ack(ackMessage); + if (connectionState.UsingAcks()) + { + connectionState.Ack(ackMessage); + } break; case SequenceMessage sequenceMessage: _logger.LogInformation("Received SequenceMessage with ID {id}", sequenceMessage.SequenceId); - connectionState.ResetSequence(sequenceMessage); + if (connectionState.UsingAcks()) + { + connectionState.ResetSequence(sequenceMessage); + } break; default: throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}"); @@ -1843,6 +1847,7 @@ private sealed class ConnectionState : IInvocationBinder private readonly HubConnection _hubConnection; private readonly ILogger _logger; private readonly bool _hasInherentKeepAlive; + private readonly MessageBuffer? _messageBuffer; private readonly object _lock = new object(); private readonly Dictionary _pendingCalls = new Dictionary(StringComparer.Ordinal); @@ -1855,8 +1860,6 @@ private sealed class ConnectionState : IInvocationBinder private long _nextActivationServerTimeout; private long _nextActivationSendPing; - private MessageBuffer? _buffer; - public ConnectionContext Connection { get; } public Task? ReceiveTask { get; set; } public Exception? CloseException { get; set; } @@ -1883,13 +1886,11 @@ public ConnectionState(ConnectionContext connection, HubConnection hubConnection _logger = _hubConnection._logger; _hasInherentKeepAlive = connection.Features.Get()?.HasInherentKeepAlive ?? false; - var useAck = true; - if (useAck) + if (Connection.Features.Get() is IReconnectFeature feature) { - _buffer = new MessageBuffer(connection, hubConnection._protocol); + _messageBuffer = new MessageBuffer(connection, hubConnection._protocol); - var f = Connection.Features.Get(); - f.NotifyOnReconnect = _buffer.Resend; + feature.NotifyOnReconnect = _messageBuffer.Resend; } } @@ -1976,7 +1977,7 @@ private async Task StopAsyncCore() { Log.Stopping(_logger); - _buffer.Dispose(); + _messageBuffer?.Dispose(); // Complete our write pipe, which should cause everything to shut down Log.TerminatingReceiveLoop(_logger); @@ -2011,24 +2012,31 @@ public async Task TimerLoop(TimerAwaitable timer) public ValueTask WriteAsync(SerializedHubMessage message, CancellationToken cancellationToken) { - return _buffer.WriteAsync(message, cancellationToken); + Debug.Assert(_messageBuffer is not null); + return _messageBuffer.WriteAsync(message, cancellationToken); } - public bool ShouldProcessMessage(HubInvocationMessage message) + public bool ShouldProcessMessage(HubMessage message) { - return _buffer.ShouldProcessMessage(message); + Debug.Assert(_messageBuffer is not null); + return _messageBuffer.ShouldProcessMessage(message); } public void Ack(AckMessage ackMessage) { - _buffer.Ack(ackMessage); + Debug.Assert(_messageBuffer is not null); + _messageBuffer.Ack(ackMessage); } public void ResetSequence(SequenceMessage sequenceMessage) { - _buffer.ResetSequence(sequenceMessage); + Debug.Assert(_messageBuffer is not null); + _messageBuffer.ResetSequence(sequenceMessage); } + [MemberNotNullWhen(true, nameof(_messageBuffer))] + public bool UsingAcks() => _messageBuffer is not null; + public void ResetSendPing() { Volatile.Write(ref _nextActivationSendPing, (DateTime.UtcNow + _hubConnection.KeepAliveInterval).Ticks); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index aa08bfd42e77..57ebab814f71 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -531,7 +531,7 @@ private async Task StartTransport(Uri connectUrl, HttpTransportType transportTyp // We successfully started, set the transport properties (we don't want to set these until the transport is definitely running). _transport = transport; - if (_httpConnectionOptions.UseAcks && _transport is IReconnectFeature reconnectFeature) + if (useAck && _transport is IReconnectFeature reconnectFeature) { Features.Set(reconnectFeature); } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index a043a7506270..b193df958f5c 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -655,6 +655,7 @@ private void UpdateConnectionPair() prevPipe.Complete(new ConnectionResetException("")); + Debug.Assert(_notifyOnReconnect is not null); _notifyOnReconnect.Invoke(); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 4626aa0024b9..0f3ceaafd62d 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -19,11 +19,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal; -internal sealed class Reconnect : IReconnectFeature -{ - public Action NotifyOnReconnect { get; set; } -} - internal sealed partial class HttpConnectionContext : ConnectionContext, IConnectionIdFeature, IConnectionItemsFeature, @@ -35,7 +30,8 @@ internal sealed partial class HttpConnectionContext : ConnectionContext, IHttpTransportFeature, IConnectionInherentKeepAliveFeature, IConnectionLifetimeFeature, - IConnectionLifetimeNotificationFeature + IConnectionLifetimeNotificationFeature, + IReconnectFeature { private readonly HttpConnectionDispatcherOptions _options; @@ -100,8 +96,7 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge if (useAcks) { - var reconnectFeature = new Reconnect(); - Features.Set(reconnectFeature); + Features.Set(this); } _connectionClosedTokenSource = new CancellationTokenSource(); @@ -207,6 +202,8 @@ public IDuplexPipe Application public CancellationToken ConnectionClosedRequested { get; set; } + public Action NotifyOnReconnect { get; set; } = () => { }; + public override void Abort() { ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts!).Cancel(), _connectionClosedTokenSource); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 3086ee363ba1..c970679841ce 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -4,7 +4,6 @@ using System.Collections.Concurrent; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.IO.Pipelines; using System.Net.WebSockets; using System.Security.Cryptography; using Microsoft.Extensions.Hosting; diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 693f8b2fa4f7..c535ac180714 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; using System.Diagnostics; using System.IO.Pipelines; using System.Net.WebSockets; diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index e2639691336a..8380578dfbf5 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -4,6 +4,7 @@ using System; using System.IO.Pipelines; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; @@ -18,6 +19,8 @@ internal sealed class MessageBuffer : IDisposable private readonly IHubProtocol _protocol; private readonly AckMessage _ackMessage = new(0); private readonly SequenceMessage _sequenceMessage = new(0); + private readonly Channel _waitForAck = Channel.CreateBounded(new BoundedChannelOptions(1) { FullMode = BoundedChannelFullMode.DropOldest }); + #if NET8_0_OR_GREATER private readonly PeriodicTimer _timer = new(TimeSpan.FromSeconds(1)); #else @@ -27,6 +30,7 @@ internal sealed class MessageBuffer : IDisposable private int _bufferIndex; private long _totalMessageCount; + private bool _waitForSequenceMessage; // Message IDs start at 1 and always increment by 1 private long _currentReceivingSequenceId = 1; @@ -67,9 +71,9 @@ private async Task RunTimer() if (_lastAckedId < _latestReceivedSequenceId) { // TODO: consider a minimum time between sending these? - // If we only read and don't write, this approach isn't great var sequenceId = _latestReceivedSequenceId; + _ackMessage.SequenceId = sequenceId; await _writeLock.WaitAsync().ConfigureAwait(false); @@ -88,18 +92,28 @@ private async Task RunTimer() } } + // TODO: WriteAsync(HubMessage) overload, so we don't allocate SerializedHubMessage for messages that aren't going to be buffered public async ValueTask WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken) { - // No lock because this is always called in a single async loop? - // And other methods don't affect the checks here? - - // TODO: Backpressure - + // TODO: Backpressure based on message count and total message size if (_buffer[_bufferIndex].Message is not null) { - // ... + // primitive backpressure if buffer is full + while (await _waitForAck.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (_waitForAck.Reader.TryRead(out var index) + && (index == _bufferIndex || _buffer[_bufferIndex].Message is null)) + { + break; + } + } } + // Avoid condition where last Ack position is the position we're currently writing into the buffer + // If we wrote messages around the entire buffer before another Ack arrived we would end up reading the Ack position and writing over a buffered message + _waitForAck.Reader.TryRead(out _); + + // TODO: We could consider buffering messages until they hit backpressure in the case when the connection is down await _resend.Task.ConfigureAwait(false); var waitForResend = false; @@ -119,10 +133,11 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, _buffer[_bufferIndex] = (hubMessage, _totalMessageCount); _bufferIndex = (_bufferIndex + 1) % _buffer.Length; + return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); } // TODO: figure out what exception to use - catch (ConnectionResetException ex) + catch (ConnectionResetException) { waitForResend = true; } @@ -151,22 +166,48 @@ public void Ack(AckMessage ackMessage) // Grabbing _bufferIndex unsynchronized should be fine, we might miss the most recent message but the client shouldn't be able to ack that yet // Or in exceptional cases we could miss multiple messages, but the next ack will clear them var index = _bufferIndex; + var finalIndex = -1; for (var i = 0; i < _buffer.Length; i++) { var currentIndex = (index + i) % _buffer.Length; if (_buffer[currentIndex].Message is not null && _buffer[currentIndex].SequenceId <= ackMessage.SequenceId) { _buffer[currentIndex] = (null, long.MinValue); + finalIndex = currentIndex; } // TODO: figure out an early exit? } - // Release backpressure? + // Release backpressure + if (finalIndex > 0) + { + _waitForAck.Writer.TryWrite(finalIndex); + } } - internal bool ShouldProcessMessage(HubInvocationMessage message) + internal bool ShouldProcessMessage(HubMessage message) { - // TODO: if we're expecting a sequence message but get here we should probably error + // TODO: if we're expecting a sequence message but get here should we error or ignore? + if (_waitForSequenceMessage) + { + if (message is SequenceMessage) + { + _waitForSequenceMessage = false; + return true; + } + else + { + // ignore messages received while waiting for sequence message + return false; + } + } + + // Only care about messages implementing HubInvocationMessage currently (e.g. ignore ping, close, ack, sequence) + // Could expand in the future, but should probably rev the ack version if changes are made + if (message is not HubInvocationMessage) + { + return true; + } var currentId = _currentReceivingSequenceId; _currentReceivingSequenceId++; @@ -193,10 +234,14 @@ internal void ResetSequence(SequenceMessage sequenceMessage) internal void Resend() { + _waitForSequenceMessage = true; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var oldTcs = Interlocked.Exchange(ref _resend, tcs); + // WriteAsync can also try to swap the TCS, we need to check if it's completed to know if it was swapped or not if (!oldTcs.Task.IsCompleted) { + // Swap back to the TCS created by WriteAsync since it's waiting on the result of that task Interlocked.Exchange(ref _resend, oldTcs); tcs = oldTcs; } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs index 755e3ed5f563..33ef91881fa9 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol; /// -/// +/// Represents the ID being acknowledged so we can stop buffering older messages. /// public sealed class AckMessage : HubMessage { @@ -29,12 +29,22 @@ public AckMessage(long sequenceId) public long SequenceId { get; set; } } +/// +/// Represents the restart of the sequence of messages being sent. is the starting ID of messages being sent, which might be duplicate messages. +/// public sealed class SequenceMessage : HubMessage { + /// + /// + /// + /// public SequenceMessage(long sequenceId) { SequenceId = sequenceId; } + /// + /// + /// public long SequenceId { get; set; } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs index 0d32dbc2f235..c5e67987ae92 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs @@ -48,5 +48,8 @@ public static class HubProtocolConstants /// public const int AckMessageType = 8; + /// + /// + /// public const int SequenceMessageType = 9; } diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt index 0afbd3fec7cc..29b26f2e7839 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt @@ -1,11 +1,11 @@ #nullable enable +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.AckMessageType = 8 -> int +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.SequenceMessageType = 9 -> int Microsoft.AspNetCore.SignalR.Protocol.AckMessage -Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(string! sequenceId) -> void -Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> string! -Microsoft.AspNetCore.SignalR.Protocol.CancelInvocationMessage.CancelInvocationMessage(string! invocationId, string! sequenceId) -> void -Microsoft.AspNetCore.SignalR.Protocol.CompletionMessage.CompletionMessage(string! invocationId, string! sequenceId, string? error, object? result, bool hasResult) -> void -Microsoft.AspNetCore.SignalR.Protocol.HubInvocationMessage.HubInvocationMessage(string? invocationId, string! sequenceId) -> void -Microsoft.AspNetCore.SignalR.Protocol.HubInvocationMessage.SequenceId.get -> string? -Microsoft.AspNetCore.SignalR.Protocol.HubMethodInvocationMessage.HubMethodInvocationMessage(string? invocationId, string! sequenceId, string! target, object?[]! arguments, string![]? streamIds) -> void -Microsoft.AspNetCore.SignalR.Protocol.StreamInvocationMessage.StreamInvocationMessage(string! invocationId, string! sequenceId, string! target, object?[]! arguments, string![]? streamIds) -> void -Microsoft.AspNetCore.SignalR.Protocol.StreamItemMessage.StreamItemMessage(string! invocationId, string! sequenceId, object? item) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(long sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceMessage(long sequenceId) -> void diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt index 7dc5c58110bf..29b26f2e7839 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -1 +1,11 @@ #nullable enable +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.AckMessageType = 8 -> int +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.SequenceMessageType = 9 -> int +Microsoft.AspNetCore.SignalR.Protocol.AckMessage +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(long sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceMessage(long sequenceId) -> void diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt index 7dc5c58110bf..29b26f2e7839 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt @@ -1 +1,11 @@ #nullable enable +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.AckMessageType = 8 -> int +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.SequenceMessageType = 9 -> int +Microsoft.AspNetCore.SignalR.Protocol.AckMessage +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(long sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceMessage(long sequenceId) -> void diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index e003f5b9ec95..d42201048ebe 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -52,7 +52,6 @@ public partial class HubConnectionContext private ClaimsPrincipal? _user; internal bool UseAcks; - private long _sequenceId; /// /// Initializes a new instance of the class. @@ -259,10 +258,9 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c { try { - // TODO - var isAck = true; - if (isAck) + if (UseAcks) { + Debug.Assert(_messageBuffer is not null); return _messageBuffer.WriteAsync(new SerializedHubMessage(message), cancellationToken); } else @@ -289,10 +287,9 @@ private ValueTask WriteCore(SerializedHubMessage message, Cancellat { try { - // TODO - var isAck = true; - if (isAck) + if (UseAcks) { + Debug.Assert(_messageBuffer is not null); return _messageBuffer.WriteAsync(message, cancellationToken); } else @@ -574,9 +571,12 @@ await WriteHandshakeResponseAsync(new HandshakeResponseMessage( await WriteHandshakeResponseAsync(HandshakeResponseMessage.Empty); - _messageBuffer = new MessageBuffer(_connectionContext, Protocol); - var f = _connectionContext.Features.Get(); - f.NotifyOnReconnect = _messageBuffer.Resend; + if (_connectionContext.Features.Get() is IReconnectFeature feature) + { + UseAcks = true; + _messageBuffer = new MessageBuffer(_connectionContext, Protocol); + feature.NotifyOnReconnect = _messageBuffer.Resend; + } return true; } else if (overLength) @@ -752,7 +752,7 @@ internal void StopClientTimeout() internal void Cleanup() { - _messageBuffer.Dispose(); + _messageBuffer?.Dispose(); _closedRegistration.Dispose(); _closedRequestedRegistration?.Dispose(); @@ -762,19 +762,19 @@ internal void Cleanup() internal void Ack(AckMessage ackMessage) { - // Remove from ring buffer + Debug.Assert(_messageBuffer is not null); _messageBuffer.Ack(ackMessage); } - private long _currentReceivingSequenceId; - - internal bool ShouldProcessMessage(HubInvocationMessage message) + internal bool ShouldProcessMessage(HubMessage message) { + Debug.Assert(_messageBuffer is not null); return _messageBuffer.ShouldProcessMessage(message); } internal void ResetSequence(SequenceMessage sequenceMessage) { + Debug.Assert(_messageBuffer is not null); _messageBuffer.ResetSequence(sequenceMessage); } } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index f0aa90491a34..21e9061897e8 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -70,7 +70,6 @@ IServiceScopeFactory serviceScopeFactory _enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors; _maxParallelInvokes = _hubOptions.MaximumParallelInvocationsPerClient; disableImplicitFromServiceParameters = _hubOptions.DisableImplicitFromServicesParameters; - var _ = _hubOptions.UseAcks; if (_hubOptions.HubFilters != null) { @@ -83,7 +82,6 @@ IServiceScopeFactory serviceScopeFactory _enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors; _maxParallelInvokes = _globalHubOptions.MaximumParallelInvocationsPerClient; disableImplicitFromServiceParameters = _globalHubOptions.DisableImplicitFromServicesParameters; - var _ = _globalHubOptions.UseAcks; if (_globalHubOptions.HubFilters != null) { @@ -96,6 +94,7 @@ IServiceScopeFactory serviceScopeFactory new HubContext(lifetimeManager), _enableDetailedErrors, disableImplicitFromServiceParameters, + // TODO useAcks: true, new Logger>(loggerFactory), hubFilters, diff --git a/src/SignalR/server/Core/src/HubOptions.cs b/src/SignalR/server/Core/src/HubOptions.cs index 9e2a2fbe979b..3a4e0883f884 100644 --- a/src/SignalR/server/Core/src/HubOptions.cs +++ b/src/SignalR/server/Core/src/HubOptions.cs @@ -79,6 +79,4 @@ public int MaximumParallelInvocationsPerClient /// False by default. Hub method arguments will be resolved from a DI container if possible. /// public bool DisableImplicitFromServicesParameters { get; set; } - - public bool UseAcks { get; set; } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index a5ce40a49311..cff89408367e 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -74,9 +74,6 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { - // TODO: figure out when this should be true - connection.UseAcks = true; - await using var scope = _serviceScopeFactory.CreateAsyncScope(); var hubActivator = scope.ServiceProvider.GetRequiredService>(); @@ -135,11 +132,11 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe // With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running. - if (_useAcks && hubMessage is HubInvocationMessage invocation) + if (connection.UseAcks) { - if (!connection.ShouldProcessMessage(invocation)) + if (!connection.ShouldProcessMessage(hubMessage)) { - _logger.LogInformation($"dropping {invocation.GetType().Name}. ID: {invocation.InvocationId}"); + _logger.LogInformation($"dropping {((HubInvocationMessage)hubMessage).GetType().Name}. ID: {((HubInvocationMessage)hubMessage).InvocationId}"); return Task.CompletedTask; } } @@ -202,12 +199,18 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe case AckMessage ackMessage: _logger.LogInformation("received ack with id {id}", ackMessage.SequenceId); - connection.Ack(ackMessage); + if (connection.UseAcks) + { + connection.Ack(ackMessage); + } break; case SequenceMessage sequenceMessage: _logger.LogInformation("received sequence message with id {id}", sequenceMessage.SequenceId); - connection.ResetSequence(sequenceMessage); + if (connection.UseAcks) + { + connection.ResetSequence(sequenceMessage); + } break; // Other kind of message we weren't expecting From 8972c04e74bdf98b158d1a6c9ac048b3c15bffa1 Mon Sep 17 00:00:00 2001 From: Brennan Date: Fri, 19 May 2023 19:13:17 -0700 Subject: [PATCH 22/25] fix exception --- src/SignalR/common/Shared/MessageBuffer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index 8380578dfbf5..7ba5774d0cc7 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -227,7 +227,7 @@ internal void ResetSequence(SequenceMessage sequenceMessage) if (sequenceMessage.SequenceId > _currentReceivingSequenceId) { - throw new Exception("Sequence ID greater than amount we've acked"); + throw new InvalidOperationException("Sequence ID greater than amount we've acked"); } _currentReceivingSequenceId = sequenceMessage.SequenceId; } From 76d8c7aa7cdd8dcdc0c1f36a28fb0a5aa415e809 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Mon, 22 May 2023 14:16:29 -0700 Subject: [PATCH 23/25] fb and cleanup --- .../Client.Core/src/HubConnection.Log.cs | 9 +++ .../csharp/Client.Core/src/HubConnection.cs | 45 +++++------ .../src/HttpConnection.cs | 2 +- .../src/Internal/WebSocketsTransport.Log.cs | 6 ++ .../src/Internal/WebSocketsTransport.cs | 15 ++-- .../src/Internal/HttpConnectionContext.cs | 5 +- .../src/Internal/HttpConnectionDispatcher.cs | 2 +- .../WebSocketsServerTransport.Log.cs | 3 + .../Transports/WebSocketsServerTransport.cs | 25 +++--- src/SignalR/common/Shared/MessageBuffer.cs | 36 +++++---- src/SignalR/docs/specs/TransportProtocols.md | 77 +------------------ .../server/Core/src/HubConnectionContext.cs | 30 +++++--- .../Core/src/Internal/DefaultHubDispatcher.cs | 23 ++---- .../src/Internal/DefaultHubDispatcherLog.cs | 9 +++ 14 files changed, 122 insertions(+), 165 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 3aa896767775..070fb2d2d43c 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -325,5 +325,14 @@ public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTime [LoggerMessage(89, LogLevel.Trace, "Error sending Completion message for stream '{StreamId}'.", EventName = "ErrorSendingStreamCompletion")] public static partial void ErrorSendingStreamCompletion(ILogger logger, string streamId, Exception exception); + + [LoggerMessage(90, LogLevel.Trace, "Dropping {MessageType} with ID '{InvocationId}'.", EventName = "DroppingMessage")] + public static partial void DroppingMessage(ILogger logger, string messageType, string? invocationId); + + [LoggerMessage(91, LogLevel.Trace, "Received AckMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedAckMessage")] + public static partial void ReceivedAckMessage(ILogger logger, long sequenceId); + + [LoggerMessage(92, LogLevel.Trace, "Received SequenceMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedSequenceMessage")] + public static partial void ReceivedSequenceMessage(ILogger logger, long sequenceId); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 190d30e1d8e2..f2b9b70eda1f 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -1014,13 +1014,9 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess Log.ResettingKeepAliveTimer(_logger); connectionState.ResetTimeout(); - if (connectionState.UsingAcks()) + if (!connectionState.ShouldProcessMessage(message)) { - if (!connectionState.ShouldProcessMessage(message)) - { - _logger.LogInformation($"Dropped {((HubInvocationMessage)message).GetType().Name}. ID: {((HubInvocationMessage)message).InvocationId}"); - return null; - } + return null; } InvocationRequest? irq; @@ -1075,18 +1071,12 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess // timeout is reset above, on receiving any message break; case AckMessage ackMessage: - _logger.LogInformation("Received Ack with ID {id}", ackMessage.SequenceId); - if (connectionState.UsingAcks()) - { - connectionState.Ack(ackMessage); - } + Log.ReceivedAckMessage(_logger, ackMessage.SequenceId); + connectionState.Ack(ackMessage); break; case SequenceMessage sequenceMessage: - _logger.LogInformation("Received SequenceMessage with ID {id}", sequenceMessage.SequenceId); - if (connectionState.UsingAcks()) - { - connectionState.ResetSequence(sequenceMessage); - } + Log.ReceivedSequenceMessage(_logger, sequenceMessage.SequenceId); + connectionState.ResetSequence(sequenceMessage); break; default: throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}"); @@ -2018,20 +2008,31 @@ public ValueTask WriteAsync(SerializedHubMessage message, Cancellat public bool ShouldProcessMessage(HubMessage message) { - Debug.Assert(_messageBuffer is not null); - return _messageBuffer.ShouldProcessMessage(message); + if (UsingAcks()) + { + if (!_messageBuffer.ShouldProcessMessage(message)) + { + Log.DroppingMessage(_logger, ((HubInvocationMessage)message).GetType().Name, ((HubInvocationMessage)message).InvocationId); + return false; + } + } + return true; } public void Ack(AckMessage ackMessage) { - Debug.Assert(_messageBuffer is not null); - _messageBuffer.Ack(ackMessage); + if (UsingAcks()) + { + _messageBuffer.Ack(ackMessage); + } } public void ResetSequence(SequenceMessage sequenceMessage) { - Debug.Assert(_messageBuffer is not null); - _messageBuffer.ResetSequence(sequenceMessage); + if (UsingAcks()) + { + _messageBuffer.ResetSequence(sequenceMessage); + } } [MemberNotNullWhen(true, nameof(_messageBuffer))] diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 57ebab814f71..22c250f865f7 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -313,7 +313,7 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets) { Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); - await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken, false).ConfigureAwait(false); + await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken, useAck: false).ConfigureAwait(false); } else { diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs index 1540844ec4b0..8792a85bd2a9 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs @@ -72,5 +72,11 @@ private static partial class Log [LoggerMessage(20, LogLevel.Warning, $"Configuring request headers using {nameof(HttpConnectionOptions)}.{nameof(HttpConnectionOptions.Headers)} is not supported when using websockets transport " + "on the browser platform.", EventName = "HeadersNotSupported")] public static partial void HeadersNotSupported(ILogger logger); + + [LoggerMessage(21, LogLevel.Debug, "Receive loop errored.", EventName = "ReceiveErrored")] + public static partial void ReceiveErrored(ILogger logger, Exception exception); + + [LoggerMessage(22, LogLevel.Debug, "Send loop errored.", EventName = "SendErrored")] + public static partial void SendErrored(ILogger logger, Exception exception); } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index b193df958f5c..578ff3460bc5 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -329,7 +329,7 @@ private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirs var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false); _stopCts.CancelAfter(_closeTimeout); - _logger.LogInformation("starting close"); + if (trigger == receiving) { // We're waiting for the application to finish and there are 2 things it could be doing @@ -437,7 +437,6 @@ private async Task StartReceiving(WebSocket socket) // or if the consumer is done if (flushResult.IsCanceled || flushResult.IsCompleted) { - _logger.LogInformation("receive: pipe canceled or completed"); break; } } @@ -456,10 +455,9 @@ private async Task StartReceiving(WebSocket socket) } else { - //_application.Output.CancelPendingFlush(); + // only logging in this case because the other case gets the exception flowed to application code + Log.ReceiveErrored(_logger, ex); } - //_closed = true; - _logger.LogInformation(ex, "receive error"); } } finally @@ -511,7 +509,6 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) } else { - socket.Dispose(); break; } } @@ -526,7 +523,6 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) } else if (result.IsCompleted) { - _logger.LogInformation("send: pipe result completed"); break; } } @@ -572,8 +568,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) { if (error is not null) { - // TODO: log error in else? - _logger.LogInformation(error, "send error"); + Log.SendErrored(_logger, error); } } @@ -647,12 +642,14 @@ private void UpdateConnectionPair() var prevPipe = _application!.Input; var input = new Pipe(_httpConnectionOptions.TransportPipeOptions); + // Add new pipe for reading from and writing to transport from app code var transportToApplication = new DuplexPipe(_transport!.Input, input.Writer); var applicationToTransport = new DuplexPipe(input.Reader, _application!.Output); _application = applicationToTransport; _transport = transportToApplication; + // Close previous pipe with specific error that application code can catch to know a restart is occurring prevPipe.Complete(new ConnectionResetException("")); Debug.Assert(_notifyOnReconnect is not null); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 0f3ceaafd62d..8ec21012979f 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -548,6 +548,7 @@ internal async Task CancelPreviousPoll(HttpContext context) // TODO: remove transport check once other transports support acks if (UseAcks && TransportType == HttpTransportType.WebSockets) { + // Break transport send loop in case it's still waiting on reading from the application Application.Input.CancelPendingRead(); UpdateConnectionPair(); } @@ -657,14 +658,16 @@ private void UpdateConnectionPair() var prevPipe = Application.Input; var input = new Pipe(_options.TransportPipeOptions); + // Add new pipe for reading from and writing to transport from app code var transportToApplication = new DuplexPipe(Transport.Input, input.Writer); var applicationToTransport = new DuplexPipe(input.Reader, Application.Output); Application = applicationToTransport; Transport = transportToApplication; + // Close previous pipe with specific error that application code can catch to know a restart is occurring prevPipe.Complete(new ConnectionResetException("")); - Features.GetRequiredFeature().NotifyOnReconnect?.Invoke(); + Features.GetRequiredFeature().NotifyOnReconnect.Invoke(); } private static partial class Log diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index e50bd09fee47..de6cc55921a0 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -274,7 +274,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti { // If false then the transport was ungracefully closed, this can mean a temporary network disconnection // We'll mark the connection as inactive and allow the connection to reconnect if that's the case. - // TODO: If acks aren't enabled we can close the connection immediately + // TODO: If acks aren't enabled we can close the connection immediately (not LongPolling) if (await connection.TransportTask!) { await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs index 026285075b7d..d085cf8b0f84 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs @@ -51,5 +51,8 @@ private static partial class Log [LoggerMessage(15, LogLevel.Debug, "Closing webSocket failed.", EventName = "ClosingWebSocketFailed")] public static partial void ClosingWebSocketFailed(ILogger logger, Exception ex); + + [LoggerMessage(16, LogLevel.Debug, "Send loop errored.", EventName = "SendErrored")] + public static partial void SendErrored(ILogger logger, Exception exception); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index c535ac180714..77892ab942dc 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -18,7 +18,7 @@ internal sealed partial class WebSocketsServerTransport : IHttpTransport private volatile bool _aborted; // Used to determine if the close was graceful or a network issue - private bool _closed; + private bool _gracefulClose; public WebSocketsServerTransport(WebSocketOptions options, IDuplexPipe application, HttpConnectionContext connection, ILoggerFactory loggerFactory) { @@ -54,7 +54,7 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationTok } } - return _closed; + return _gracefulClose; } public async Task ProcessSocketAsync(WebSocket socket) @@ -141,7 +141,7 @@ private async Task StartReceiving(WebSocket socket) if (result.MessageType == WebSocketMessageType.Close) { - _closed = true; + _gracefulClose = true; return; } @@ -152,7 +152,7 @@ private async Task StartReceiving(WebSocket socket) // Need to check again for netcoreapp3.0 and later because a close can happen between a 0-byte read and the actual read if (receiveResult.MessageType == WebSocketMessageType.Close) { - _closed = true; + _gracefulClose = true; return; } @@ -183,13 +183,13 @@ private async Task StartReceiving(WebSocket socket) { if (!_aborted && !token.IsCancellationRequested) { - _closed = true; + _gracefulClose = true; _application.Output.Complete(ex); } } finally { - if (_closed) + if (_gracefulClose) { // We're done writing _application.Output.Complete(); @@ -241,7 +241,7 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) } catch (OperationCanceledException ex) when (ex.CancellationToken == _connection.SendingToken) { - _closed = true; + _gracefulClose = true; // TODO: probably log break; } @@ -286,15 +286,14 @@ private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) } } - if (_closed) + if (_gracefulClose) { _application.Input.Complete(error); } - // TODO - //else if (error is not null) - //{ - // _logger.LogError("Error in send {ex}.", error); - //} + else if (error is not null) + { + Log.SendErrored(_logger, error); + } } } diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index 7ba5774d0cc7..ef7a49b6c322 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -14,6 +14,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class MessageBuffer : IDisposable { + private static readonly TaskCompletionSource _completedTCS = new TaskCompletionSource(); + private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; private readonly ConnectionContext _connection; private readonly IHubProtocol _protocol; @@ -37,12 +39,19 @@ internal sealed class MessageBuffer : IDisposable private long _latestReceivedSequenceId = long.MinValue; private long _lastAckedId = long.MinValue; - private TaskCompletionSource _resend = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private TaskCompletionSource _resend = _completedTCS; + + static MessageBuffer() + { + _completedTCS.SetResult(new()); + } // TODO: pass in limits public MessageBuffer(ConnectionContext connection, IHubProtocol protocol) { - _buffer = new (SerializedHubMessage? Message, long SequenceId)[10]; + // Arbitrary size, we can figure out defaults and configurability later + const int bufferSize = 10; + _buffer = new (SerializedHubMessage? Message, long SequenceId)[bufferSize]; for (var i = 0; i < _buffer.Length; i++) { _buffer[i].SequenceId = long.MinValue; @@ -50,8 +59,6 @@ public MessageBuffer(ConnectionContext connection, IHubProtocol protocol) _connection = connection; _protocol = protocol; - _resend.SetResult(new()); - #if !NET8_0_OR_GREATER _timer.Start(); #endif @@ -250,21 +257,20 @@ internal void Resend() private async Task DoResendAsync(TaskCompletionSource tcs) { - long latestAckedIndex = -1; - for (var i = 0; i < _buffer.Length - 1; i++) - { - // TODO: this could grab the index of the just written message from WriteAsync which would result in the wrong value for latestAckedIndex if there are more than 1 messages buffered - if (_buffer[(_bufferIndex + i + 1) % _buffer.Length].SequenceId > long.MinValue) - { - latestAckedIndex = (_bufferIndex + i + 1) % _buffer.Length; - break; - } - } - FlushResult finalResult = new(); await _writeLock.WaitAsync().ConfigureAwait(false); try { + long latestAckedIndex = -1; + for (var i = 0; i < _buffer.Length - 1; i++) + { + if (_buffer[(_bufferIndex + i + 1) % _buffer.Length].SequenceId > long.MinValue) + { + latestAckedIndex = (_bufferIndex + i + 1) % _buffer.Length; + break; + } + } + if (latestAckedIndex == -1) { // no unacked messages, still send SequenceMessage? diff --git a/src/SignalR/docs/specs/TransportProtocols.md b/src/SignalR/docs/specs/TransportProtocols.md index 4e155b9b8952..e2fb3d28cb17 100644 --- a/src/SignalR/docs/specs/TransportProtocols.md +++ b/src/SignalR/docs/specs/TransportProtocols.md @@ -30,7 +30,7 @@ The client may close the connection if the "negotiateVersion" in the response is *useAck:* -In the POST request the client may include a query string parameter with the key "useAck" and the value of "true". If this is included the server will decide if it supports/allows the [ack protocol](#ack-protocol) described below, and return "useAck": "true" as a json property in the negotiate response if it will use the ack protocol. If true, the client must use the ack protocol when sending/receiving otherwise the connection will be terminated. Similarly, the server must use the ack protocol when sending/receiving. If false, the client must not use the ack protocol and will be terminated if it does. If the "useAck" property is missing from the negotiate response this also implies false, so the ack protocol should not be used. +In the POST request the client may include a query string parameter with the key "useAck" and the value of "true". If this is included the server will decide if it supports/allows the [ack protocol](#todo) described below, and return "useAck": "true" as a json property in the negotiate response if it will use the ack protocol. If true, the client can reconnect using the same transport and reuse the connectionToken/connectionId. The server may still reject the reconnect if it takes too long or for any reason it chooses. If false, the client must not reuse connectionToken/connectionId. If the "useAck" property is missing from the negotiate response this also implies false, so the ack protocol should not be used. ----------- @@ -205,78 +205,3 @@ When data is available, the server responds with a body in one of the two format If the `id` parameter is missing, a `400 Bad Request` response is returned. If there is no connection with the ID specified in `id`, a `404 Not Found` response is returned. When the client has finished with the connection, it can issue a `DELETE` request to `[endpoint-base]` (with the `id` in the query string) to gracefully terminate the connection. The server will complete the latest poll with `204` to indicate that it has shut down. - -## Ack Protocol - -The ack protocol primarily consists of writing and reading framing around the data being sent and received. -All sends need to start with a 24 byte frame. The frame consists of 2 64-bit little-endian values (8 bytes), both base-64 encoded (preserving padding) for a total of 2 12 byte base-64 values. The first base-64 value when decoded is the length of the payload being sent (minus the framing) as an int64 value. The second base-64 value when decoded is the ack ID as an int64 of how many bytes have been received from the other side so far. - -The second part of the protocol is for when the transport ungracefully reconnects and uses the Ack IDs to get any data that might have been missed during the disconnect window. This will be described after showing the framing. - -### Framing - -Consider the following example: - -0x41 0x67 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d 0x48 0x51 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x41 0x3d 0x48 0x69 - -This is a 26 byte message, the first 24 bytes are the framing, which we'll split into two 12 byte sections and the 2 remaining bytes -(hex) 41 67 41 41 41 41 41 41 41 41 41 3d - Base64 represention as bytes -AgAAAAAAAAA= - Base64 representation in ASCII -2 0 0 0 0 0 0 0 0 0 0 0 - Base64 decoded, int64 value of 2, representing a 2 length payload after the framing - -(hex) 48 51 41 41 41 41 41 41 41 41 41 3d - Base64 represention as bytes -HQAAAAAAAAA= - Base64 representation in ASCII -29 0 0 0 0 0 0 0 0 0 0 0 - Base64 decoded, int64 value of 29, representing an ack id of 29 bytes received from the endpoint so far - -0x48 0x69 -Hi - -From now on we'll use `[ , ]` annotation to represent the framing, with an implicit payload attached to it. - -To explain the Ack IDs we'll use the following example which is sending between a client and server, C and S respectively: - -``` -C->S: [ 5, 0 ] -S->C: [ 10, 29 ] -S->C: [ 13, 29 ] -C->S: [ 22, 71 ] -S->C: [ 1, 75 ] -``` - -The first send will send an Ack ID of 0 because the client hasn't received any data yet, so there is nothing to ack. When the server sends after it's received a message from the client it will send an Ack ID of the payload length (5) + the frame length (24), so 29. In this example we also send another message which won't have an updated Ack ID, because nothing new was received, so we send the previous value. The client in its next send adds all the received messages together to get the Ack ID to send to the server, 24 + 10 from the first message received, 24 + 13 from the second message received, for a total of 71. And then finally, the server adds its previously sent Ack ID of 29 with the message(s) received since its last send (24 + 22), for a total of 75 for the Ack ID it sends to the client. - -### Reconnect - -The second part of the protocol is what makes use of the Ack IDs. - -If a transport ungracefully disconnects the client can attempt to reconnect using the same `id` it was using before. The server is free to reject any reconnect attempts, but generally should allow a few seconds grace period. - -On a successful reconnect the client must send an Ack ID with a 0 length payload to the server indicating the last message it received before disconnecting. The client then waits for a message from the server that will contain the last Ack ID the server received before the disconnection, as well as a 0 length payload. This message **does not** increment the Ack ID tracking. The Ack ID received from the server will be used to send any missed messages from the client to the server. The normal send/receive loops can now start and if there is any unacked data on the client side the send loop should immediately send the missed data (framing and all). - -On a successful reconnect the server must wait for the client to send the last Ack ID it received before disconnecting. This message **does not** increment the Ack ID tracking. The Ack ID received from the client will be used to send any missed messages from the server to the client. The server will then send the last Ack ID it received before the disconnect occurred as well as a 0 length payload. The normal send/receive loops can now start and if there is any unacked data on the server side the send loop should immediately send the missed data (framing and all). - -The following example will send a few messages between client and server before having an ungraceful disconnect to show the reconnect flow: - -``` -C->S: [ 10, 0 ] -S->C: [ 1, 34 ] -C->S: [ 11, 25 ] -// Ungraceful disconnect -C->S: [ 0, 25 ] -S->C: [ 0, 34 ] -// normal send/receive loops for both sides are now started -C->S: [ 11, 25 ] // resend 11 byte payload that server didn't get before disconnect occurred -``` - -Another example that is the same as the last example except that the server did receive the clients last send before the disconnect: - -``` -C->S: [ 10, 0 ] -S->C: [ 1, 34 ] -C->S: [ 11, 25 ] -// Ungraceful disconnect -C->S: [ 0, 25 ] -S->C: [ 0, 69 ] -// normal send/receive loops for both sides are now started -// 11 bytes from C->S not resent because server did get it before the disconnect, as can be seen by the new Ack ID -``` \ No newline at end of file diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index d42201048ebe..d9dcc8beca2e 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -50,8 +50,10 @@ public partial class HubConnectionContext private TimeSpan _receivedMessageElapsed; private long _receivedMessageTick; private ClaimsPrincipal? _user; + private bool _useAcks; - internal bool UseAcks; + [MemberNotNullWhen(true, nameof(_messageBuffer))] + internal bool UsingAcks() => _useAcks; /// /// Initializes a new instance of the class. @@ -258,9 +260,8 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c { try { - if (UseAcks) + if (UsingAcks()) { - Debug.Assert(_messageBuffer is not null); return _messageBuffer.WriteAsync(new SerializedHubMessage(message), cancellationToken); } else @@ -287,7 +288,7 @@ private ValueTask WriteCore(SerializedHubMessage message, Cancellat { try { - if (UseAcks) + if (UsingAcks()) { Debug.Assert(_messageBuffer is not null); return _messageBuffer.WriteAsync(message, cancellationToken); @@ -573,7 +574,7 @@ await WriteHandshakeResponseAsync(new HandshakeResponseMessage( if (_connectionContext.Features.Get() is IReconnectFeature feature) { - UseAcks = true; + _useAcks = true; _messageBuffer = new MessageBuffer(_connectionContext, Protocol); feature.NotifyOnReconnect = _messageBuffer.Resend; } @@ -762,19 +763,26 @@ internal void Cleanup() internal void Ack(AckMessage ackMessage) { - Debug.Assert(_messageBuffer is not null); - _messageBuffer.Ack(ackMessage); + if (UsingAcks()) + { + _messageBuffer.Ack(ackMessage); + } } internal bool ShouldProcessMessage(HubMessage message) { - Debug.Assert(_messageBuffer is not null); - return _messageBuffer.ShouldProcessMessage(message); + if (UsingAcks()) + { + return _messageBuffer.ShouldProcessMessage(message); + } + return true; } internal void ResetSequence(SequenceMessage sequenceMessage) { - Debug.Assert(_messageBuffer is not null); - _messageBuffer.ResetSequence(sequenceMessage); + if (UsingAcks()) + { + _messageBuffer.ResetSequence(sequenceMessage); + } } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index cff89408367e..a632f59a7538 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -132,13 +132,10 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe // With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running. - if (connection.UseAcks) + if (!connection.ShouldProcessMessage(hubMessage)) { - if (!connection.ShouldProcessMessage(hubMessage)) - { - _logger.LogInformation($"dropping {((HubInvocationMessage)hubMessage).GetType().Name}. ID: {((HubInvocationMessage)hubMessage).InvocationId}"); - return Task.CompletedTask; - } + Log.DroppingMessage(_logger, ((HubInvocationMessage)hubMessage).GetType().Name, ((HubInvocationMessage)hubMessage).InvocationId); + return Task.CompletedTask; } switch (hubMessage) @@ -198,19 +195,13 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe break; case AckMessage ackMessage: - _logger.LogInformation("received ack with id {id}", ackMessage.SequenceId); - if (connection.UseAcks) - { - connection.Ack(ackMessage); - } + Log.ReceivedAckMessage(_logger, ackMessage.SequenceId); + connection.Ack(ackMessage); break; case SequenceMessage sequenceMessage: - _logger.LogInformation("received sequence message with id {id}", sequenceMessage.SequenceId); - if (connection.UseAcks) - { - connection.ResetSequence(sequenceMessage); - } + Log.ReceivedSequenceMessage(_logger, sequenceMessage.SequenceId); + connection.ResetSequence(sequenceMessage); break; // Other kind of message we weren't expecting diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs index 24d6ee8b7fca..dce901a2d952 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs @@ -108,4 +108,13 @@ public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessa [LoggerMessage(25, LogLevel.Error, "Invocation ID {InvocationId}: Failed while sending stream items from hub method {HubMethod}.", EventName = "FailedStreaming")] public static partial void FailedStreaming(ILogger logger, string invocationId, string hubMethod, Exception exception); + + [LoggerMessage(26, LogLevel.Trace, "Dropping {MessageType} with ID '{InvocationId}'.", EventName = "DroppingMessage")] + public static partial void DroppingMessage(ILogger logger, string messageType, string? invocationId); + + [LoggerMessage(27, LogLevel.Trace, "Received AckMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedAckMessage")] + public static partial void ReceivedAckMessage(ILogger logger, long sequenceId); + + [LoggerMessage(28, LogLevel.Trace, "Received SequenceMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedSequenceMessage")] + public static partial void ReceivedSequenceMessage(ILogger logger, long sequenceId); } From 352391e4a1114eba866b9d3d4ff5f1d44b6c59d8 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Tue, 23 May 2023 14:20:34 -0700 Subject: [PATCH 24/25] lock + some tests and fixes --- .../FunctionalTests/HubConnectionTests.cs | 3 - src/SignalR/common/Shared/MessageBuffer.cs | 56 +++- .../test/Internal/MessageBufferTests.cs | 291 ++++++++++++++++++ 3 files changed, 330 insertions(+), 20 deletions(-) create mode 100644 src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 087717353cf1..966ed454e54a 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -2539,7 +2539,6 @@ public async Task ServerSentEventsWorksWithHttp2OnlyEndpoint() } [Fact] - [Repeat(500)] public async Task CanReconnectAndSendMessageWhileDisconnected() { var protocol = HubProtocols["json"]; @@ -2599,7 +2598,6 @@ public async Task CanReconnectAndSendMessageWhileDisconnected() } [Fact] - [Repeat(1500)] public async Task CanReconnectAndSendMessageOnceConnected() { var protocol = HubProtocols["json"]; @@ -2669,7 +2667,6 @@ public async Task CanReconnectAndSendMessageOnceConnected() } [Fact] - [Repeat(500)] public async Task ServerAbortsConnectionWithAckingEnabledNoReconnectAttempted() { var protocol = HubProtocols["json"]; diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index ef7a49b6c322..ab5e0a5189ab 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -41,6 +41,8 @@ internal sealed class MessageBuffer : IDisposable private TaskCompletionSource _resend = _completedTCS; + private object Lock => _buffer; + static MessageBuffer() { _completedTCS.SetResult(new()); @@ -99,19 +101,25 @@ private async Task RunTimer() } } + /// + /// Calling code is assumed to not call this method in parallel. Currently HubConnection and HubConnectionContext respect that. + /// // TODO: WriteAsync(HubMessage) overload, so we don't allocate SerializedHubMessage for messages that aren't going to be buffered public async ValueTask WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken) { // TODO: Backpressure based on message count and total message size - if (_buffer[_bufferIndex].Message is not null) + if (_buffer[_bufferIndex].Message is not null || _buffer[_bufferIndex].SequenceId > long.MinValue) { // primitive backpressure if buffer is full while (await _waitForAck.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { - if (_waitForAck.Reader.TryRead(out var index) - && (index == _bufferIndex || _buffer[_bufferIndex].Message is null)) + lock (Lock) { - break; + if (_waitForAck.Reader.TryRead(out var index) + && (index == _bufferIndex || _buffer[_bufferIndex].Message is null)) + { + break; + } } } } @@ -138,7 +146,10 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); } - _buffer[_bufferIndex] = (hubMessage, _totalMessageCount); + lock (Lock) + { + _buffer[_bufferIndex] = (hubMessage, _totalMessageCount); + } _bufferIndex = (_bufferIndex + 1) % _buffer.Length; return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); @@ -174,19 +185,23 @@ public void Ack(AckMessage ackMessage) // Or in exceptional cases we could miss multiple messages, but the next ack will clear them var index = _bufferIndex; var finalIndex = -1; - for (var i = 0; i < _buffer.Length; i++) + + lock (Lock) { - var currentIndex = (index + i) % _buffer.Length; - if (_buffer[currentIndex].Message is not null && _buffer[currentIndex].SequenceId <= ackMessage.SequenceId) + for (var i = 0; i < _buffer.Length; i++) { - _buffer[currentIndex] = (null, long.MinValue); - finalIndex = currentIndex; + var currentIndex = (index + i) % _buffer.Length; + if (_buffer[currentIndex].Message is not null && _buffer[currentIndex].SequenceId <= ackMessage.SequenceId) + { + _buffer[currentIndex] = (null, long.MinValue); + finalIndex = currentIndex; + } + // TODO: figure out an early exit? } - // TODO: figure out an early exit? } // Release backpressure - if (finalIndex > 0) + if (finalIndex >= 0) { _waitForAck.Writer.TryWrite(finalIndex); } @@ -194,7 +209,7 @@ public void Ack(AckMessage ackMessage) internal bool ShouldProcessMessage(HubMessage message) { - // TODO: if we're expecting a sequence message but get here should we error or ignore? + // TODO: if we're expecting a sequence message but get here should we error or ignore or maybe even process them? if (_waitForSequenceMessage) { if (message is SequenceMessage) @@ -234,7 +249,7 @@ internal void ResetSequence(SequenceMessage sequenceMessage) if (sequenceMessage.SequenceId > _currentReceivingSequenceId) { - throw new InvalidOperationException("Sequence ID greater than amount we've acked"); + throw new InvalidOperationException("Sequence ID greater than amount of messages we've received."); } _currentReceivingSequenceId = sequenceMessage.SequenceId; } @@ -252,6 +267,7 @@ internal void Resend() Interlocked.Exchange(ref _resend, oldTcs); tcs = oldTcs; } + _ = DoResendAsync(tcs); } @@ -264,16 +280,22 @@ private async Task DoResendAsync(TaskCompletionSource tcs) long latestAckedIndex = -1; for (var i = 0; i < _buffer.Length - 1; i++) { - if (_buffer[(_bufferIndex + i + 1) % _buffer.Length].SequenceId > long.MinValue) + var currentIndex = (_bufferIndex + i + 1) % _buffer.Length; + if (_buffer[currentIndex].SequenceId > long.MinValue) { - latestAckedIndex = (_bufferIndex + i + 1) % _buffer.Length; + latestAckedIndex = currentIndex; break; } } if (latestAckedIndex == -1) { - // no unacked messages, still send SequenceMessage? + // no unacked messages, still send SequenceMessage as other side is expecting it, see _waitForSequenceMessage + + // Add 1 because this ID is used to set what the next ID to be received is, and since everything has been received so far, this needs to be the next ID that is going to be sent + _sequenceMessage.SequenceId = _totalMessageCount + 1; + _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); + finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); return; } diff --git a/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs b/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs new file mode 100644 index 000000000000..52c5b78e67af --- /dev/null +++ b/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs @@ -0,0 +1,291 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO.Pipelines; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.SignalR.Tests.Internal; + +public class MessageBufferTests +{ + [Fact] + public async Task CanWriteNonBufferedMessagesWithoutBlocking() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + for (var i = 0; i < 100; i++) + { + await messageBuffer.WriteAsync(new SerializedHubMessage(PingMessage.Instance), default).DefaultTimeout(); + } + + var count = 0; + while (count < 100) + { + var res = await pipes.Application.Input.ReadAsync().DefaultTimeout(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + count++; + } + } + + [Fact] + public async Task WriteBlocksOnAckWhenBufferFull() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + for (var i = 0; i < 10; ++i) + { + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + } + + var writeTask = messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + Assert.False(writeTask.IsCompleted); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + // Write not unblocked by read, only unblocked after ack received + Assert.False(writeTask.IsCompleted); + + messageBuffer.Ack(new AckMessage(1)); + await writeTask.DefaultTimeout(); + + var count = 0; + while (count < 10) + { + res = await pipes.Application.Input.ReadAsync().DefaultTimeout(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + count++; + } + } + + [Fact] + public async Task UnAckedMessageResentOnReconnect() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + // Any message except SequenceMessage will be ignored until a SequenceMessage is received + Assert.False(messageBuffer.ShouldProcessMessage(PingMessage.Instance)); + Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); + Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1))); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + messageBuffer.ResetSequence(new SequenceMessage(1)); + + Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance)); + Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); + } + + [Fact] + public async Task AckedMessageNotResentOnReconnect() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + messageBuffer.Ack(new AckMessage(1)); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(2, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + await messageBuffer.WriteAsync(new SerializedHubMessage(CompletionMessage.WithResult("1", null)), default); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + } + + [Fact] + public async Task ReceiveSequenceMessageWithLargerIDThanMessagesReceived() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + Assert.Throws(() => messageBuffer.ResetSequence(new SequenceMessage(2))); + } +} + +internal sealed class TestConnectionContext : ConnectionContext +{ + public override string ConnectionId { get; set; } + public override IFeatureCollection Features { get; } = new FeatureCollection(); + public override IDictionary Items { get; set; } + public override IDuplexPipe Transport { get; set; } +} + +internal sealed class DuplexPipe : IDuplexPipe +{ + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + var transportToApplication = new DuplexPipe(output.Reader, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + // This class exists to work around issues with value tuple on .NET Framework + public struct DuplexPipePair + { + public IDuplexPipe Transport { get; set; } + public IDuplexPipe Application { get; set; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } + + public static void UpdateConnectionPair(ref DuplexPipePair duplexPipePair, ConnectionContext connection) + { + var prevPipe = duplexPipePair.Application.Input; + var input = new Pipe(); + + // Add new pipe for reading from and writing to transport from app code + var transportToApplication = new DuplexPipe(duplexPipePair.Transport.Input, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, duplexPipePair.Application.Output); + + duplexPipePair.Application = applicationToTransport; + duplexPipePair.Transport = transportToApplication; + + connection.Transport = duplexPipePair.Transport; + + // Close previous pipe with specific error that application code can catch to know a restart is occurring + prevPipe.Complete(new ConnectionResetException("")); + } +} + +internal sealed class TestBinder : IInvocationBinder +{ + public IReadOnlyList GetParameterTypes(string methodName) + { + var list = new List + { + typeof(object) + }; + return list; + } + + public Type GetReturnType(string invocationId) + { + return typeof(object); + } + + public Type GetStreamItemType(string streamId) + { + return typeof(object); + } +} From 342749c9bb99492acc5845cf53a0b66cc885c12d Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Wed, 24 May 2023 15:15:28 -0700 Subject: [PATCH 25/25] byte count backpressure --- src/SignalR/common/Shared/MessageBuffer.cs | 311 ++++++++++++++---- .../test/Internal/MessageBufferTests.cs | 66 +++- 2 files changed, 290 insertions(+), 87 deletions(-) diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs index ab5e0a5189ab..01c3b2850349 100644 --- a/src/SignalR/common/Shared/MessageBuffer.cs +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -2,6 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; using System.IO.Pipelines; using System.Threading; using System.Threading.Channels; @@ -16,12 +19,12 @@ internal sealed class MessageBuffer : IDisposable { private static readonly TaskCompletionSource _completedTCS = new TaskCompletionSource(); - private readonly (SerializedHubMessage? Message, long SequenceId)[] _buffer; private readonly ConnectionContext _connection; private readonly IHubProtocol _protocol; private readonly AckMessage _ackMessage = new(0); private readonly SequenceMessage _sequenceMessage = new(0); private readonly Channel _waitForAck = Channel.CreateBounded(new BoundedChannelOptions(1) { FullMode = BoundedChannelFullMode.DropOldest }); + private readonly int _bufferLimit = 100 * 1000; #if NET8_0_OR_GREATER private readonly PeriodicTimer _timer = new(TimeSpan.FromSeconds(1)); @@ -30,7 +33,6 @@ internal sealed class MessageBuffer : IDisposable #endif private readonly SemaphoreSlim _writeLock = new(1, 1); - private int _bufferIndex; private long _totalMessageCount; private bool _waitForSequenceMessage; @@ -43,6 +45,9 @@ internal sealed class MessageBuffer : IDisposable private object Lock => _buffer; + private LinkedBuffer _buffer; + private int _bufferedByteCount; + static MessageBuffer() { _completedTCS.SetResult(new()); @@ -51,13 +56,9 @@ static MessageBuffer() // TODO: pass in limits public MessageBuffer(ConnectionContext connection, IHubProtocol protocol) { - // Arbitrary size, we can figure out defaults and configurability later - const int bufferSize = 10; - _buffer = new (SerializedHubMessage? Message, long SequenceId)[bufferSize]; - for (var i = 0; i < _buffer.Length; i++) - { - _buffer[i].SequenceId = long.MinValue; - } + // TODO: pool + _buffer = new LinkedBuffer(); + _connection = connection; _protocol = protocol; @@ -108,18 +109,14 @@ private async Task RunTimer() public async ValueTask WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken) { // TODO: Backpressure based on message count and total message size - if (_buffer[_bufferIndex].Message is not null || _buffer[_bufferIndex].SequenceId > long.MinValue) + if (_bufferedByteCount > _bufferLimit) { // primitive backpressure if buffer is full while (await _waitForAck.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { - lock (Lock) + if (_waitForAck.Reader.TryRead(out var count) && count < _bufferLimit) { - if (_waitForAck.Reader.TryRead(out var index) - && (index == _bufferIndex || _buffer[_bufferIndex].Message is null)) - { - break; - } + break; } } } @@ -146,13 +143,14 @@ public async ValueTask WriteAsync(SerializedHubMessage hubMessage, return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); } + var messageBytes = hubMessage.GetSerializedMessage(_protocol); lock (Lock) { - _buffer[_bufferIndex] = (hubMessage, _totalMessageCount); + _bufferedByteCount += messageBytes.Length; + _buffer.AddMessage(hubMessage, _totalMessageCount); } - _bufferIndex = (_bufferIndex + 1) % _buffer.Length; - return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); + return await _connection.Transport.Output.WriteAsync(messageBytes, cancellationToken).ConfigureAwait(false); } // TODO: figure out what exception to use catch (ConnectionResetException) @@ -181,35 +179,27 @@ public void Ack(AckMessage ackMessage) { // TODO: what if ackMessage.SequenceId is larger than last sent message? - // Grabbing _bufferIndex unsynchronized should be fine, we might miss the most recent message but the client shouldn't be able to ack that yet - // Or in exceptional cases we could miss multiple messages, but the next ack will clear them - var index = _bufferIndex; - var finalIndex = -1; + var newCount = -1; lock (Lock) { - for (var i = 0; i < _buffer.Length; i++) - { - var currentIndex = (index + i) % _buffer.Length; - if (_buffer[currentIndex].Message is not null && _buffer[currentIndex].SequenceId <= ackMessage.SequenceId) - { - _buffer[currentIndex] = (null, long.MinValue); - finalIndex = currentIndex; - } - // TODO: figure out an early exit? - } + var item = _buffer.RemoveMessages(ackMessage.SequenceId, _protocol); + _buffer = item.Item1; + _bufferedByteCount -= item.Item2; + + newCount = _bufferedByteCount; } - // Release backpressure - if (finalIndex >= 0) + // Release potential backpressure + if (newCount >= 0) { - _waitForAck.Writer.TryWrite(finalIndex); + _waitForAck.Writer.TryWrite(newCount); } } internal bool ShouldProcessMessage(HubMessage message) { - // TODO: if we're expecting a sequence message but get here should we error or ignore or maybe even process them? + // TODO: if we're expecting a sequence message but get here should we error or ignore or maybe even continue to process them? if (_waitForSequenceMessage) { if (message is SequenceMessage) @@ -277,65 +267,246 @@ private async Task DoResendAsync(TaskCompletionSource tcs) await _writeLock.WaitAsync().ConfigureAwait(false); try { - long latestAckedIndex = -1; - for (var i = 0; i < _buffer.Length - 1; i++) + _sequenceMessage.SequenceId = _totalMessageCount + 1; + + var isFirst = true; + foreach (var item in _buffer.GetMessages()) { - var currentIndex = (_bufferIndex + i + 1) % _buffer.Length; - if (_buffer[currentIndex].SequenceId > long.MinValue) + if (item.SequenceId > 0) { - latestAckedIndex = currentIndex; - break; + if (isFirst) + { + _sequenceMessage.SequenceId = item.SequenceId; + _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); + isFirst = false; + } + finalResult = await _connection.Transport.Output.WriteAsync(item.HubMessage!.GetSerializedMessage(_protocol)).ConfigureAwait(false); } } - if (latestAckedIndex == -1) + if (isFirst) { - // no unacked messages, still send SequenceMessage as other side is expecting it, see _waitForSequenceMessage - - // Add 1 because this ID is used to set what the next ID to be received is, and since everything has been received so far, this needs to be the next ID that is going to be sent - _sequenceMessage.SequenceId = _totalMessageCount + 1; _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); - return; } + } + catch (Exception ex) + { + tcs.SetException(ex); + } + finally + { + _writeLock.Release(); + tcs.TrySetResult(finalResult); + } + } + + public void Dispose() + { + ((IDisposable)_timer).Dispose(); + } - _sequenceMessage.SequenceId = _buffer[latestAckedIndex].SequenceId; - _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); - // don't need to call flush just for the SequenceMessage if we're writing more messages - var shouldFlush = true; + // Linked list of SerializedHubMessage arrays, sort of like ReadOnlySequence + private sealed class LinkedBuffer + { + private const int BufferLength = 10; - for (var i = 0; i < _buffer.Length; i++) + private int _currentIndex = -1; + private int _ackedIndex = -1; + private long _startingSequenceId = long.MinValue; + private LinkedBuffer? _next; + + private readonly SerializedHubMessage?[] _messages = new SerializedHubMessage?[BufferLength]; + + public void AddMessage(SerializedHubMessage hubMessage, long sequenceId) + { + if (_startingSequenceId < 0) + { + Debug.Assert(_currentIndex == -1); + _startingSequenceId = sequenceId; + } + + if (_currentIndex < BufferLength - 1) + { + Debug.Assert(_startingSequenceId + _currentIndex + 1 == sequenceId); + + _currentIndex++; + _messages[_currentIndex] = hubMessage; + } + else if (_next is null) { - var item = _buffer[(latestAckedIndex + i) % _buffer.Length]; - if (item.SequenceId > long.MinValue) + _next = new LinkedBuffer(); + _next.AddMessage(hubMessage, sequenceId); + } + else + { + // TODO: Should we avoid this path by keeping a tail pointer? + // Debug.Assert(false); + + var linkedBuffer = _next; + while (linkedBuffer._next is not null) { - finalResult = await _connection.Transport.Output.WriteAsync(item.Message!.GetSerializedMessage(_protocol)).ConfigureAwait(false); - shouldFlush = false; + linkedBuffer = linkedBuffer._next; + } + + // TODO: verify no stack overflow potential + linkedBuffer.AddMessage(hubMessage, sequenceId); + } + } + + public (LinkedBuffer, int returnCredit) RemoveMessages(long sequenceId, IHubProtocol protocol) + { + return RemoveMessagesCore(this, sequenceId, protocol); + } + + private static (LinkedBuffer, int returnCredit) RemoveMessagesCore(LinkedBuffer linkedBuffer, long sequenceId, IHubProtocol protocol) + { + var returnCredit = 0; + while (linkedBuffer._startingSequenceId <= sequenceId) + { + var numElements = (int)Math.Min(BufferLength, Math.Max(1, sequenceId - (linkedBuffer._startingSequenceId - 1))); + Debug.Assert(numElements > 0 && numElements < BufferLength + 1); + + for (var i = 0; i < numElements; i++) + { + returnCredit += linkedBuffer._messages[i]?.GetSerializedMessage(protocol).Length ?? 0; + linkedBuffer._messages[i] = null; + } + + linkedBuffer._ackedIndex = numElements - 1; + + if (numElements == BufferLength) + { + if (linkedBuffer._next is null) + { + linkedBuffer.Reset(shouldPool: false); + return (linkedBuffer, returnCredit); + } + else + { + var tmp = linkedBuffer; + linkedBuffer = linkedBuffer._next; + tmp.Reset(shouldPool: true); + } } else { - break; + return (linkedBuffer, returnCredit); } } - if (shouldFlush) + return (linkedBuffer, returnCredit); + } + + private void Reset(bool shouldPool) + { + _startingSequenceId = long.MinValue; + _currentIndex = -1; + _ackedIndex = -1; + _next = null; + + Array.Clear(_messages, 0, BufferLength); + + // TODO: Add back to pool + if (shouldPool) { - finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); } } - catch (Exception ex) + + public IEnumerable<(SerializedHubMessage? HubMessage, long SequenceId)> GetMessages() { - tcs.SetException(ex); + return new Enumerable(this); } - finally + + private struct Enumerable : IEnumerable<(SerializedHubMessage?, long)> { - _writeLock.Release(); - tcs.TrySetResult(finalResult); + private readonly LinkedBuffer _linkedBuffer; + + public Enumerable(LinkedBuffer linkedBuffer) + { + _linkedBuffer = linkedBuffer; + } + + public IEnumerator<(SerializedHubMessage?, long)> GetEnumerator() + { + return new Enumerator(_linkedBuffer); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } } - } - public void Dispose() - { - ((IDisposable)_timer).Dispose(); + private struct Enumerator : IEnumerator<(SerializedHubMessage?, long)> + { + private LinkedBuffer? _linkedBuffer; + private int _index; + + public Enumerator(LinkedBuffer linkedBuffer) + { + _linkedBuffer = linkedBuffer; + } + + public (SerializedHubMessage?, long) Current + { + get + { + if (_linkedBuffer is null) + { + return (null, long.MinValue); + } + + var index = _index - 1; + var firstMessageIndex = _linkedBuffer._ackedIndex + 1; + if (firstMessageIndex + index < BufferLength) + { + return (_linkedBuffer._messages[firstMessageIndex + index], _linkedBuffer._startingSequenceId + firstMessageIndex + index); + } + + return (null, long.MinValue); + } + } + + object IEnumerator.Current => throw new NotImplementedException(); + + public void Dispose() + { + _linkedBuffer = null; + } + + public bool MoveNext() + { + if (_linkedBuffer is null) + { + return false; + } + + var firstMessageIndex = _linkedBuffer._ackedIndex + 1; + if (firstMessageIndex + _index >= BufferLength) + { + _linkedBuffer = _linkedBuffer._next; + _index = 1; + } + else + { + if (_linkedBuffer._messages[firstMessageIndex + _index] is null) + { + _linkedBuffer = null; + } + else + { + _index++; + } + } + + return _linkedBuffer is not null; + } + + public void Reset() + { + throw new NotImplementedException(); + } + } } } diff --git a/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs b/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs index 52c5b78e67af..823470c29d81 100644 --- a/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs +++ b/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs @@ -45,14 +45,11 @@ public async Task WriteBlocksOnAckWhenBufferFull() { var protocol = new JsonHubProtocol(); var connection = new TestConnectionContext(); - var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions(pauseWriterThreshold: 200000, resumeWriterThreshold: 100000)); connection.Transport = pipes.Transport; using var messageBuffer = new MessageBuffer(connection, protocol); - for (var i = 0; i < 10; ++i) - { - await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); - } + await messageBuffer.WriteAsync(new SerializedHubMessage(new InvocationMessage("t", new object[] { new byte[100000] })), default); var writeTask = messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); Assert.False(writeTask.IsCompleted); @@ -61,7 +58,7 @@ public async Task WriteBlocksOnAckWhenBufferFull() var buffer = res.Buffer; Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); - Assert.IsType(message); + Assert.IsType(message); pipes.Application.Input.AdvanceTo(buffer.Start); @@ -71,19 +68,13 @@ public async Task WriteBlocksOnAckWhenBufferFull() messageBuffer.Ack(new AckMessage(1)); await writeTask.DefaultTimeout(); - var count = 0; - while (count < 10) - { - res = await pipes.Application.Input.ReadAsync().DefaultTimeout(); + res = await pipes.Application.Input.ReadAsync().DefaultTimeout(); - buffer = res.Buffer; - Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); - Assert.IsType(message); - - pipes.Application.Input.AdvanceTo(buffer.Start); + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); - count++; - } + pipes.Application.Input.AdvanceTo(buffer.Start); } [Fact] @@ -203,6 +194,47 @@ public async Task ReceiveSequenceMessageWithLargerIDThanMessagesReceived() Assert.Throws(() => messageBuffer.ResetSequence(new SequenceMessage(2))); } + + [Fact] + public async Task WriteManyMessagesAckSomeProperlyBuffers() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + for (var i = 0; i < 1000; i++) + { + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("1", null)), default); + } + + var ackNum = Random.Shared.Next(0, 1000); + messageBuffer.Ack(new AckMessage(ackNum)); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(ackNum + 1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + for (var i = 0; i < 1000 - ackNum; i++) + { + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + } + } } internal sealed class TestConnectionContext : ConnectionContext