diff --git a/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs new file mode 100644 index 000000000000..81a9ebe587a2 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IReconnectFeature.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Connections.Abstractions; + +/// +/// +/// +public interface IReconnectFeature +{ + /// + /// + /// + public Action NotifyOnReconnect { get; set; } + + // TODO + // void DisableReconnect(); +} diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt index 3e85ed9e89fc..39ed42614d79 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/net462/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt index 9f1d00bb09ad..9e361db01313 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt index 3e85ed9e89fc..39ed42614d79 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt index 3e85ed9e89fc..39ed42614d79 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI/netstandard2.1/PublicAPI.Unshipped.txt @@ -1,4 +1,7 @@ #nullable enable +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.get -> System.Action! +Microsoft.AspNetCore.Connections.Abstractions.IReconnectFeature.NotifyOnReconnect.set -> void Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature Microsoft.AspNetCore.Connections.Features.IConnectionMetricsTagsFeature.Tags.get -> System.Collections.Generic.ICollection>! Microsoft.AspNetCore.Connections.Features.IConnectionNamedPipeFeature diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 3aa896767775..070fb2d2d43c 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -325,5 +325,14 @@ public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTime [LoggerMessage(89, LogLevel.Trace, "Error sending Completion message for stream '{StreamId}'.", EventName = "ErrorSendingStreamCompletion")] public static partial void ErrorSendingStreamCompletion(ILogger logger, string streamId, Exception exception); + + [LoggerMessage(90, LogLevel.Trace, "Dropping {MessageType} with ID '{InvocationId}'.", EventName = "DroppingMessage")] + public static partial void DroppingMessage(ILogger logger, string messageType, string? invocationId); + + [LoggerMessage(91, LogLevel.Trace, "Received AckMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedAckMessage")] + public static partial void ReceivedAckMessage(ILogger logger, long sequenceId); + + [LoggerMessage(92, LogLevel.Trace, "Received SequenceMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedSequenceMessage")] + public static partial void ReceivedSequenceMessage(ILogger logger, long sequenceId); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 797ecee716c8..f2b9b70eda1f 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -8,6 +8,7 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Net; using System.Reflection; @@ -16,6 +17,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Shared; @@ -946,11 +948,19 @@ private async Task InvokeStreamCore(ConnectionState connectionState, string meth private async Task SendHubMessage(ConnectionState connectionState, HubMessage hubMessage, CancellationToken cancellationToken = default) { _state.AssertConnectionValid(); - _protocol.WriteMessage(hubMessage, connectionState.Connection.Transport.Output); Log.SendingMessage(_logger, hubMessage); - await connectionState.Connection.Transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); + if (connectionState.UsingAcks()) + { + await connectionState.WriteAsync(new SerializedHubMessage(hubMessage), cancellationToken).ConfigureAwait(false); + } + else + { + _protocol.WriteMessage(hubMessage, connectionState.Connection.Transport.Output); + + await connectionState.Connection.Transport.Output.FlushAsync(cancellationToken).ConfigureAwait(false); + } Log.MessageSent(_logger, hubMessage); // We've sent a message, so don't ping for a while @@ -1004,6 +1014,11 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess Log.ResettingKeepAliveTimer(_logger); connectionState.ResetTimeout(); + if (!connectionState.ShouldProcessMessage(message)) + { + return null; + } + InvocationRequest? irq; switch (message) { @@ -1055,6 +1070,14 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess Log.ReceivedPing(_logger); // timeout is reset above, on receiving any message break; + case AckMessage ackMessage: + Log.ReceivedAckMessage(_logger, ackMessage.SequenceId); + connectionState.Ack(ackMessage); + break; + case SequenceMessage sequenceMessage: + Log.ReceivedSequenceMessage(_logger, sequenceMessage.SequenceId); + connectionState.ResetSequence(sequenceMessage); + break; default: throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}"); } @@ -1235,6 +1258,7 @@ private async Task HandshakeAsync(ConnectionState startingConnectionState, Cance } Log.HandshakeComplete(_logger); + break; } } @@ -1813,6 +1837,7 @@ private sealed class ConnectionState : IInvocationBinder private readonly HubConnection _hubConnection; private readonly ILogger _logger; private readonly bool _hasInherentKeepAlive; + private readonly MessageBuffer? _messageBuffer; private readonly object _lock = new object(); private readonly Dictionary _pendingCalls = new Dictionary(StringComparer.Ordinal); @@ -1850,6 +1875,13 @@ public ConnectionState(ConnectionContext connection, HubConnection hubConnection _logger = _hubConnection._logger; _hasInherentKeepAlive = connection.Features.Get()?.HasInherentKeepAlive ?? false; + + if (Connection.Features.Get() is IReconnectFeature feature) + { + _messageBuffer = new MessageBuffer(connection, hubConnection._protocol); + + feature.NotifyOnReconnect = _messageBuffer.Resend; + } } public string GetNextId() => (++_nextInvocationId).ToString(CultureInfo.InvariantCulture); @@ -1935,6 +1967,8 @@ private async Task StopAsyncCore() { Log.Stopping(_logger); + _messageBuffer?.Dispose(); + // Complete our write pipe, which should cause everything to shut down Log.TerminatingReceiveLoop(_logger); Connection.Transport.Input.CancelPendingRead(); @@ -1966,6 +2000,44 @@ public async Task TimerLoop(TimerAwaitable timer) } } + public ValueTask WriteAsync(SerializedHubMessage message, CancellationToken cancellationToken) + { + Debug.Assert(_messageBuffer is not null); + return _messageBuffer.WriteAsync(message, cancellationToken); + } + + public bool ShouldProcessMessage(HubMessage message) + { + if (UsingAcks()) + { + if (!_messageBuffer.ShouldProcessMessage(message)) + { + Log.DroppingMessage(_logger, ((HubInvocationMessage)message).GetType().Name, ((HubInvocationMessage)message).InvocationId); + return false; + } + } + return true; + } + + public void Ack(AckMessage ackMessage) + { + if (UsingAcks()) + { + _messageBuffer.Ack(ackMessage); + } + } + + public void ResetSequence(SequenceMessage sequenceMessage) + { + if (UsingAcks()) + { + _messageBuffer.ResetSequence(sequenceMessage); + } + } + + [MemberNotNullWhen(true, nameof(_messageBuffer))] + public bool UsingAcks() => _messageBuffer is not null; + public void ResetSendPing() { Volatile.Write(ref _nextActivationSendPing, (DateTime.UtcNow + _hubConnection.KeepAliveInterval).Ticks); diff --git a/src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs b/src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs new file mode 100644 index 000000000000..4fd1d41d5df9 --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/Internal/SerializedHubMessage.cs @@ -0,0 +1,191 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +/// +/// Represents a serialization cache for a single message. +/// +internal class SerializedHubMessage +{ + private SerializedMessage _cachedItem1; + private SerializedMessage _cachedItem2; + private List? _cachedItems; + private readonly object _lock = new object(); + + /// + /// Gets the hub message for the serialization cache. + /// + public HubMessage? Message { get; } + + /// + /// Initializes a new instance of the class. + /// + /// A collection of already serialized messages to cache. + public SerializedHubMessage(IReadOnlyList messages) + { + // A lock isn't needed here because nobody has access to this type until the constructor finishes. + for (var i = 0; i < messages.Count; i++) + { + var message = messages[i]; + SetCacheUnsynchronized(message.ProtocolName, message.Serialized); + } + } + + /// + /// Initializes a new instance of the class. + /// + /// The hub message for the cache. This will be serialized with an in to get the message's serialized representation. + public SerializedHubMessage(HubMessage message) + { + Message = message; + } + + /// + /// Gets the serialized representation of the using the specified . + /// + /// The protocol used to create the serialized representation. + /// The serialized representation of the . + public ReadOnlyMemory GetSerializedMessage(IHubProtocol protocol) + { + lock (_lock) + { + if (!TryGetCachedUnsynchronized(protocol.Name, out var serialized)) + { + if (Message == null) + { + throw new InvalidOperationException( + "This message was received from another server that did not have the requested protocol available."); + } + + serialized = protocol.GetMessageBytes(Message); + SetCacheUnsynchronized(protocol.Name, serialized); + } + + return serialized; + } + } + + // Used for unit testing. + internal IReadOnlyList GetAllSerializations() + { + // Even if this is only used in tests, let's do it right. + lock (_lock) + { + if (_cachedItem1.ProtocolName == null) + { + return Array.Empty(); + } + + var list = new List(2); + list.Add(_cachedItem1); + + if (_cachedItem2.ProtocolName != null) + { + list.Add(_cachedItem2); + + if (_cachedItems != null) + { + list.AddRange(_cachedItems); + } + } + + return list; + } + } + + private void SetCacheUnsynchronized(string protocolName, ReadOnlyMemory serialized) + { + // We set the fields before moving on to the list, if we need it to hold more than 2 items. + // We have to read/write these fields under the lock because the structs might tear and another + // thread might observe them half-assigned + + if (_cachedItem1.ProtocolName == null) + { + _cachedItem1 = new SerializedMessage(protocolName, serialized); + } + else if (_cachedItem2.ProtocolName == null) + { + _cachedItem2 = new SerializedMessage(protocolName, serialized); + } + else + { + if (_cachedItems == null) + { + _cachedItems = new List(); + } + + foreach (var item in _cachedItems) + { + if (string.Equals(item.ProtocolName, protocolName, StringComparison.Ordinal)) + { + // No need to add + return; + } + } + + _cachedItems.Add(new SerializedMessage(protocolName, serialized)); + } + } + + private bool TryGetCachedUnsynchronized(string protocolName, out ReadOnlyMemory result) + { + if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = _cachedItem1.Serialized; + return true; + } + + if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = _cachedItem2.Serialized; + return true; + } + + if (_cachedItems != null) + { + foreach (var serializedMessage in _cachedItems) + { + if (string.Equals(serializedMessage.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = serializedMessage.Serialized; + return true; + } + } + } + + result = default; + return false; + } +} + +internal readonly struct SerializedMessage +{ + /// + /// Gets the protocol of the serialized message. + /// + public string ProtocolName { get; } + + /// + /// Gets the serialized representation of the message. + /// + public ReadOnlyMemory Serialized { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The protocol of the serialized message. + /// The serialized representation of the message. + public SerializedMessage(string protocolName, ReadOnlyMemory serialized) + { + ProtocolName = protocolName; + Serialized = serialized; + } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj index 65cf5d649c19..e0cd9f43d95b 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj +++ b/src/SignalR/clients/csharp/Client.Core/src/Microsoft.AspNetCore.SignalR.Client.Core.csproj @@ -14,6 +14,7 @@ + diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index c3d34fa616de..966ed454e54a 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -3,6 +3,7 @@ using System.Net; using System.Net.Http; +using System.Net.WebSockets; using System.Text.Json; using System.Threading.Channels; using Microsoft.AspNetCore.Connections; @@ -2537,6 +2538,186 @@ public async Task ServerSentEventsWorksWithHttp2OnlyEndpoint() } } + [Fact] + public async Task CanReconnectAndSendMessageWhileDisconnected() + { + var protocol = HubProtocols["json"]; + await using (var server = await StartServer(w => w.EventId.Name == "ReceivedUnexpectedResponse")) + { + var websocket = new ClientWebSocket(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + tcs.SetResult(); + + const string originalMessage = "SignalR"; + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => + { + o.WebSocketFactory = async (context, token) => + { + await tcs.Task; + await websocket.ConnectAsync(context.Uri, token); + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + return websocket; + }; + o.UseAcks = true; + }); + connectionBuilder.Services.AddSingleton(protocol); + var connection = connectionBuilder.Build(); + + try + { + await connection.StartAsync().DefaultTimeout(); + var originalConnectionId = connection.ConnectionId; + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + + Assert.Equal(originalMessage, result); + + var originalWebsocket = websocket; + websocket = new ClientWebSocket(); + originalWebsocket.Dispose(); + + var resultTask = connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + tcs.SetResult(); + result = await resultTask; + + Assert.Equal(originalMessage, result); + Assert.Equal(originalConnectionId, connection.ConnectionId); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + } + } + + [Fact] + public async Task CanReconnectAndSendMessageOnceConnected() + { + var protocol = HubProtocols["json"]; + await using (var server = await StartServer(w => w.EventId.Name == "ReceivedUnexpectedResponse")) + { + var websocket = new ClientWebSocket(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + const string originalMessage = "SignalR"; + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => + { + o.WebSocketFactory = async (context, token) => + { + await websocket.ConnectAsync(context.Uri, token); + tcs.SetResult(); + return websocket; + }; + o.UseAcks = true; + }) + .WithAutomaticReconnect(); + connectionBuilder.Services.AddSingleton(protocol); + var connection = connectionBuilder.Build(); + + var reconnectCalled = false; + connection.Reconnecting += ex => + { + reconnectCalled = true; + return Task.CompletedTask; + }; + + try + { + await connection.StartAsync().DefaultTimeout(); + await tcs.Task; + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var originalConnectionId = connection.ConnectionId; + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + + Assert.Equal(originalMessage, result); + + var originalWebsocket = websocket; + websocket = new ClientWebSocket(); + + originalWebsocket.Dispose(); + + await tcs.Task.DefaultTimeout(); + result = await connection.InvokeAsync(nameof(TestHub.Echo), originalMessage).DefaultTimeout(); + + Assert.Equal(originalMessage, result); + Assert.Equal(originalConnectionId, connection.ConnectionId); + Assert.False(reconnectCalled); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + } + } + + [Fact] + public async Task ServerAbortsConnectionWithAckingEnabledNoReconnectAttempted() + { + var protocol = HubProtocols["json"]; + await using (var server = await StartServer()) + { + var connectCount = 0; + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/default", HttpTransportType.WebSockets, o => + { + o.WebSocketFactory = async (context, token) => + { + connectCount++; + var ws = new ClientWebSocket(); + await ws.ConnectAsync(context.Uri, token); + return ws; + }; + o.UseAcks = true; + }); + connectionBuilder.Services.AddSingleton(protocol); + var connection = connectionBuilder.Build(); + + var closedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection.Closed += ex => + { + closedTcs.SetResult(ex); + return Task.CompletedTask; + }; + + try + { + await connection.StartAsync().DefaultTimeout(); + + await connection.SendAsync(nameof(TestHub.Abort)).DefaultTimeout(); + + Assert.Null(await closedTcs.Task.DefaultTimeout()); + Assert.Equal(HubConnectionState.Disconnected, connection.State); + Assert.Equal(1, connectCount); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().DefaultTimeout(); + } + } + } + private class OneAtATimeSynchronizationContext : SynchronizationContext, IAsyncDisposable { private readonly Channel<(SendOrPostCallback, object)> _taskQueue = Channel.CreateUnbounded<(SendOrPostCallback, object)>(); diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs index 0b270d5e4a4b..fdb2f56c175a 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs @@ -101,6 +101,11 @@ public string GetHttpProtocol() return Context.GetHttpContext()?.Request?.Protocol ?? "unknown"; } + public void Abort() + { + Context.Abort(); + } + public async Task CallWithUnserializableObject() { await Clients.All.SendAsync("Foo", Unserializable.Create()); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs index bc9952577264..f019ec2a749b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionFactoryTests.cs @@ -100,6 +100,7 @@ public void ShallowCopyHttpConnectionOptionsCopiesAllPublicProperties() { $"{nameof(HttpConnectionOptions.WebSocketFactory)}", webSocketFactory }, { $"{nameof(HttpConnectionOptions.ApplicationMaxBufferSize)}", 1L * 1024 * 1024 }, { $"{nameof(HttpConnectionOptions.TransportMaxBufferSize)}", 1L * 1024 * 1024 }, + { $"{nameof(HttpConnectionOptions.UseAcks)}", true }, }; var options = new HttpConnectionOptions(); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs index 810633c73587..ba621a57e32f 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.Negotiate.cs @@ -508,7 +508,7 @@ public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand() var transportFactory = new Mock(MockBehavior.Strict); - transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling)) + transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, false)) .Returns(new TestTransport(transferFormat: TransferFormat.Text | TransferFormat.Binary)); using (var noErrorScope = new VerifyNoErrorsScope()) @@ -523,7 +523,7 @@ await WithConnectionAsync( } [Fact] - public async Task StartSkipsOverTransportsThatDoNotSupportTheRequredTransferFormat() + public async Task StartSkipsOverTransportsThatDoNotSupportTheRequiredTransferFormat() { var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); @@ -557,7 +557,7 @@ public async Task StartSkipsOverTransportsThatDoNotSupportTheRequredTransferForm var transportFactory = new Mock(MockBehavior.Strict); - transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling)) + transportFactory.Setup(t => t.CreateTransport(HttpTransportType.LongPolling, false)) .Returns(new TestTransport(transferFormat: TransferFormat.Text | TransferFormat.Binary)); await WithConnectionAsync( diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs index 2923c387e755..07491c809d7b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/TestTransportFactory.cs @@ -15,7 +15,7 @@ public TestTransportFactory(ITransport transport) _transport = transport; } - public ITransport CreateTransport(HttpTransportType availableServerTransports) + public ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck) { return _transport; } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 1d56c314f784..22c250f865f7 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -6,12 +6,12 @@ using System.Diagnostics.CodeAnalysis; using System.IO.Pipelines; using System.Linq; -using System.Net; using System.Net.Http; using System.Net.WebSockets; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Client.Internal; using Microsoft.AspNetCore.Http.Features; @@ -313,7 +313,7 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets) { Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); - await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken).ConfigureAwait(false); + await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat, cancellationToken, useAck: false).ConfigureAwait(false); } else { @@ -398,12 +398,14 @@ private async Task SelectAndStartTransport(TransferFormat transferFormat, Cancel // The negotiation response gets cleared in the fallback scenario. if (negotiationResponse == null) { + // Temporary until other transports work + _httpConnectionOptions.UseAcks = transportType == HttpTransportType.WebSockets ? _httpConnectionOptions.UseAcks : false; negotiationResponse = await GetNegotiationResponseAsync(uri, cancellationToken).ConfigureAwait(false); connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionToken); } Log.StartingTransport(_logger, transportType, uri); - await StartTransport(connectUrl, transportType, transferFormat, cancellationToken).ConfigureAwait(false); + await StartTransport(connectUrl, transportType, transferFormat, cancellationToken, negotiationResponse.UseAcking).ConfigureAwait(false); break; } } @@ -455,6 +457,11 @@ private async Task NegotiateAsync(Uri url, HttpClient httpC uri = Utils.AppendQueryString(urlBuilder.Uri, $"negotiateVersion={_protocolVersionNumber}"); } + if (_httpConnectionOptions.UseAcks) + { + uri = Utils.AppendQueryString(uri, "useAck=true"); + } + using (var request = new HttpRequestMessage(HttpMethod.Post, uri)) { #if NET5_0_OR_GREATER @@ -500,10 +507,10 @@ private static Uri CreateConnectUrl(Uri url, string? connectionId) return Utils.AppendQueryString(url, $"id={connectionId}"); } - private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken) + private async Task StartTransport(Uri connectUrl, HttpTransportType transportType, TransferFormat transferFormat, CancellationToken cancellationToken, bool useAck) { // Construct the transport - var transport = _transportFactory.CreateTransport(transportType); + var transport = _transportFactory.CreateTransport(transportType, useAck); // Start the transport, giving it one end of the pipe try @@ -524,6 +531,11 @@ private async Task StartTransport(Uri connectUrl, HttpTransportType transportTyp // We successfully started, set the transport properties (we don't want to set these until the transport is definitely running). _transport = transport; + if (useAck && _transport is IReconnectFeature reconnectFeature) + { + Features.Set(reconnectFeature); + } + Log.TransportStarted(_logger, transportType); } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs index 03b52fa385a4..b5f2b035e071 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionFactory.cs @@ -88,7 +88,8 @@ internal static HttpConnectionOptions ShallowCopyHttpConnectionOptions(HttpConne CloseTimeout = options.CloseTimeout, DefaultTransferFormat = options.DefaultTransferFormat, ApplicationMaxBufferSize = options.ApplicationMaxBufferSize, - TransportMaxBufferSize = options.TransportMaxBufferSize + TransportMaxBufferSize = options.TransportMaxBufferSize, + UseAcks = options.UseAcks, }; if (!OperatingSystem.IsBrowser()) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs index 443dc2307786..c7a2ae78ed07 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnectionOptions.cs @@ -275,6 +275,16 @@ public Action? WebSocketConfiguration } } + /// + /// Setting to enable acking bytes sent between client and server, this allows reconnecting that preserves messages sent while disconnected. + /// Also preserves the when the reconnect is successful. + /// + /// + /// Only works with WebSockets transport currently. + /// API likely to change in future previews. + /// + public bool UseAcks { get; set; } + private static void ThrowIfUnsupportedPlatform() { if (OperatingSystem.IsBrowser()) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs index 3a99961b41bd..5b37a0c561a4 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/DefaultTransportFactory.cs @@ -31,13 +31,13 @@ public DefaultTransportFactory(HttpTransportType requestedTransportType, ILogger _accessTokenProvider = accessTokenProvider; } - public ITransport CreateTransport(HttpTransportType availableServerTransports) + public ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck) { if (_websocketsSupported && (availableServerTransports & HttpTransportType.WebSockets & _requestedTransportType) == HttpTransportType.WebSockets) { try { - return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory, _accessTokenProvider, _httpClient); + return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory, _accessTokenProvider, _httpClient, useAck); } catch (PlatformNotSupportedException ex) { @@ -49,13 +49,13 @@ public ITransport CreateTransport(HttpTransportType availableServerTransports) if ((availableServerTransports & HttpTransportType.ServerSentEvents & _requestedTransportType) == HttpTransportType.ServerSentEvents) { // We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us. - return new ServerSentEventsTransport(_httpClient!, _httpConnectionOptions, _loggerFactory); + return new ServerSentEventsTransport(_httpClient!, _httpConnectionOptions, _loggerFactory, useAck); } if ((availableServerTransports & HttpTransportType.LongPolling & _requestedTransportType) == HttpTransportType.LongPolling) { // We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us. - return new LongPollingTransport(_httpClient!, _httpConnectionOptions, _loggerFactory); + return new LongPollingTransport(_httpClient!, _httpConnectionOptions, _loggerFactory, useAck); } throw new InvalidOperationException("No requested transports available on the server."); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs index cad87cd31772..e9779ba01fa8 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ITransportFactory.cs @@ -5,5 +5,5 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; internal interface ITransportFactory { - ITransport CreateTransport(HttpTransportType availableServerTransports); + ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck); } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs index 8a14b77ba3e0..66d2a02c015b 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs @@ -19,6 +19,7 @@ internal sealed partial class LongPollingTransport : ITransport private readonly HttpClient _httpClient; private readonly ILogger _logger; private readonly HttpConnectionOptions _httpConnectionOptions; + private readonly bool _useAck; private IDuplexPipe? _application; private IDuplexPipe? _transport; // Volatile so that the poll loop sees the updated value set from a different thread @@ -32,11 +33,12 @@ internal sealed partial class LongPollingTransport : ITransport public PipeWriter Output => _transport!.Output; - public LongPollingTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null) + public LongPollingTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null, bool useAck = false) { _httpClient = httpClient; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); _httpConnectionOptions = httpConnectionOptions ?? new(); + _useAck = useAck; } public async Task StartAsync(Uri url, TransferFormat transferFormat, CancellationToken cancellationToken = default) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs index 25cbce7ba4a2..d6758849df4f 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs @@ -26,6 +26,7 @@ internal sealed partial class ServerSentEventsTransport : ITransport private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly CancellationTokenSource _inputCts = new CancellationTokenSource(); private readonly ServerSentEventsMessageParser _parser = new ServerSentEventsMessageParser(); + private readonly bool _useAck; private IDuplexPipe? _transport; private IDuplexPipe? _application; @@ -35,10 +36,11 @@ internal sealed partial class ServerSentEventsTransport : ITransport public PipeWriter Output => _transport!.Output; - public ServerSentEventsTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null) + public ServerSentEventsTransport(HttpClient httpClient, HttpConnectionOptions? httpConnectionOptions = null, ILoggerFactory? loggerFactory = null, bool useAck = false) { ArgumentNullThrowHelper.ThrowIfNull(httpClient); + _useAck = useAck; _httpClient = httpClient; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); _httpConnectionOptions = httpConnectionOptions ?? new(); diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs index 1540844ec4b0..8792a85bd2a9 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.Log.cs @@ -72,5 +72,11 @@ private static partial class Log [LoggerMessage(20, LogLevel.Warning, $"Configuring request headers using {nameof(HttpConnectionOptions)}.{nameof(HttpConnectionOptions.Headers)} is not supported when using websockets transport " + "on the browser platform.", EventName = "HeadersNotSupported")] public static partial void HeadersNotSupported(ILogger logger); + + [LoggerMessage(21, LogLevel.Debug, "Receive loop errored.", EventName = "ReceiveErrored")] + public static partial void ReceiveErrored(ILogger logger, Exception exception); + + [LoggerMessage(22, LogLevel.Debug, "Send loop errored.", EventName = "SendErrored")] + public static partial void SendErrored(ILogger logger, Exception exception); } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs index 9c3d8184fc26..578ff3460bc5 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/WebSocketsTransport.cs @@ -2,24 +2,30 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Diagnostics; +using System.IO; using System.IO.Pipelines; using System.Net; using System.Net.Http; +using System.Net.Sockets; using System.Net.WebSockets; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; +using System.Text; using System.Text.Encodings.Web; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; -internal sealed partial class WebSocketsTransport : ITransport +internal sealed partial class WebSocketsTransport : ITransport, IReconnectFeature { private WebSocket? _webSocket; private IDuplexPipe? _application; @@ -29,9 +35,14 @@ internal sealed partial class WebSocketsTransport : ITransport private volatile bool _aborted; private readonly HttpConnectionOptions _httpConnectionOptions; private readonly HttpClient? _httpClient; - private readonly CancellationTokenSource _stopCts = new CancellationTokenSource(); + private CancellationTokenSource _stopCts = default!; + private readonly bool _useAck; private IDuplexPipe? _transport; + // Used for reconnect (when enabled) to determine if the close was ungraceful or not, reconnect only happens on ungraceful disconnect + // The assumption is that a graceful close was triggered purposefully by either the client or server and a reconnect shouldn't occur + private bool _gracefulClose; + private Action? _notifyOnReconnect; internal Task Running { get; private set; } = Task.CompletedTask; @@ -39,8 +50,12 @@ internal sealed partial class WebSocketsTransport : ITransport public PipeWriter Output => _transport!.Output; - public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func> accessTokenProvider, HttpClient? httpClient) + public Action NotifyOnReconnect { get => _notifyOnReconnect is not null ? _notifyOnReconnect : () => { }; set => _notifyOnReconnect = value; } + + public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func> accessTokenProvider, HttpClient? httpClient, + bool useAck = false) { + _useAck = useAck; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); _httpConnectionOptions = httpConnectionOptions ?? new HttpConnectionOptions(); @@ -278,18 +293,29 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio Log.StartedTransport(_logger); - // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) - var pair = DuplexPipe.CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + _stopCts = new CancellationTokenSource(); + + var ignoreFirstCanceled = false; - _transport = pair.Transport; - _application = pair.Application; + if (_transport is null) + { + // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) + var pair = CreateConnectionPair(_httpConnectionOptions.TransportPipeOptions, _httpConnectionOptions.AppPipeOptions); + + _transport = pair.Transport; + _application = pair.Application; + } + else + { + ignoreFirstCanceled = true; + } // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 - Running = ProcessSocketAsync(_webSocket); + Running = ProcessSocketAsync(_webSocket, url, ignoreFirstCanceled); } - private async Task ProcessSocketAsync(WebSocket socket) + private async Task ProcessSocketAsync(WebSocket socket, Uri url, bool ignoreFirstCanceled) { Debug.Assert(_application != null); @@ -297,7 +323,7 @@ private async Task ProcessSocketAsync(WebSocket socket) { // Begin sending and receiving. var receiving = StartReceiving(socket); - var sending = StartSending(socket); + var sending = StartSending(socket, ignoreFirstCanceled); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending).ConfigureAwait(false); @@ -335,9 +361,18 @@ private async Task ProcessSocketAsync(WebSocket socket) socket.Abort(); // Cancel any pending flush so that we can quit - _application.Output.CancelPendingFlush(); + if (_gracefulClose) + { + _application.Output.CancelPendingFlush(); + } } } + + if (_useAck && !_gracefulClose) + { + UpdateConnectionPair(); + await StartAsync(url, _webSocketMessageType == WebSocketMessageType.Binary ? TransferFormat.Binary : TransferFormat.Text, default).ConfigureAwait(false); + } } private async Task StartReceiving(WebSocket socket) @@ -354,6 +389,7 @@ private async Task StartReceiving(WebSocket socket) if (result.MessageType == WebSocketMessageType.Close) { + _gracefulClose = true; Log.WebSocketClosed(_logger, socket.CloseStatus); if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure) @@ -380,6 +416,7 @@ private async Task StartReceiving(WebSocket socket) // Need to check again for netstandard2.1 because a close can happen between a 0-byte read and the actual read if (receiveResult.MessageType == WebSocketMessageType.Close) { + _gracefulClose = true; Log.WebSocketClosed(_logger, socket.CloseStatus); if (socket.CloseStatus != WebSocketCloseStatus.NormalClosure) @@ -412,19 +449,30 @@ private async Task StartReceiving(WebSocket socket) { if (!_aborted) { - _application.Output.Complete(ex); + if (_gracefulClose) + { + _application.Output.Complete(ex); + } + else + { + // only logging in this case because the other case gets the exception flowed to application code + Log.ReceiveErrored(_logger, ex); + } } } finally { // We're done writing - _application.Output.Complete(); + if (_gracefulClose) + { + _application.Output.Complete(); + } Log.ReceiveStopped(_logger); } } - private async Task StartSending(WebSocket socket) + private async Task StartSending(WebSocket socket, bool ignoreFirstCanceled) { Debug.Assert(_application != null); @@ -441,11 +489,14 @@ private async Task StartSending(WebSocket socket) try { - if (result.IsCanceled) + if (result.IsCanceled && !ignoreFirstCanceled) { + _logger.LogInformation("send canceled"); break; } + ignoreFirstCanceled = false; + if (!buffer.IsEmpty) { try @@ -509,7 +560,17 @@ private async Task StartSending(WebSocket socket) } } - _application.Input.Complete(); + if (_gracefulClose) + { + _application.Input.Complete(error); + } + else + { + if (error is not null) + { + Log.SendErrored(_logger, error); + } + } Log.SendStopped(_logger); } @@ -539,6 +600,7 @@ private static Uri ResolveWebSocketsUrl(Uri url) public async Task StopAsync() { + _gracefulClose = true; Log.TransportStopping(_logger); if (_application == null) @@ -574,4 +636,23 @@ public async Task StopAsync() Log.TransportStopped(_logger, null); } + + private void UpdateConnectionPair() + { + var prevPipe = _application!.Input; + var input = new Pipe(_httpConnectionOptions.TransportPipeOptions); + + // Add new pipe for reading from and writing to transport from app code + var transportToApplication = new DuplexPipe(_transport!.Input, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, _application!.Output); + + _application = applicationToTransport; + _transport = transportToApplication; + + // Close previous pipe with specific error that application code can catch to know a restart is occurring + prevPipe.Complete(new ConnectionResetException("")); + + Debug.Assert(_notifyOnReconnect is not null); + _notifyOnReconnect.Invoke(); + } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt b/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..80a34974298d 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Http.Connections.Client.HttpConnectionOptions.UseAcks.get -> bool +Microsoft.AspNetCore.Http.Connections.Client.HttpConnectionOptions.UseAcks.set -> void diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs index 72b6838753a4..8ff9ba259ed1 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiateProtocol.cs @@ -34,6 +34,8 @@ public static class NegotiateProtocol private static readonly JsonEncodedText ErrorPropertyNameBytes = JsonEncodedText.Encode(ErrorPropertyName); private const string NegotiateVersionPropertyName = "negotiateVersion"; private static readonly JsonEncodedText NegotiateVersionPropertyNameBytes = JsonEncodedText.Encode(NegotiateVersionPropertyName); + private const string AckPropertyName = "useAck"; + private static readonly JsonEncodedText AckPropertyNameBytes = JsonEncodedText.Encode(AckPropertyName); // Use C#7.3's ReadOnlySpan optimization for static data https://vcsjones.com/2019/02/01/csharp-readonly-span-bytes-static/ // Used to detect ASP.NET SignalR Server connection attempt @@ -64,6 +66,11 @@ public static void WriteResponse(NegotiationResponse response, IBufferWriter content) List? availableTransports = null; string? error = null; int version = 0; + bool useAck = false; var completed = false; while (!completed && reader.CheckRead()) @@ -206,6 +214,10 @@ public static NegotiationResponse ParseResponse(ReadOnlySpan content) { throw new InvalidOperationException("Detected a connection attempt to an ASP.NET SignalR Server. This client only supports connecting to an ASP.NET Core SignalR Server. See https://aka.ms/signalr-core-differences for details."); } + else if (reader.ValueTextEquals(AckPropertyNameBytes.EncodedUtf8Bytes)) + { + useAck = reader.ReadAsBoolean(AckPropertyName); + } else { reader.Skip(); @@ -249,7 +261,8 @@ public static NegotiationResponse ParseResponse(ReadOnlySpan content) AccessToken = accessToken, AvailableTransports = availableTransports, Error = error, - Version = version + Version = version, + UseAcking = useAck, }; } catch (Exception ex) diff --git a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs index 17c838137f44..4ec99bba1055 100644 --- a/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs +++ b/src/SignalR/common/Http.Connections.Common/src/NegotiationResponse.cs @@ -44,4 +44,9 @@ public class NegotiationResponse /// An optional error during the negotiate. If this is not null the other properties on the response can be ignored. /// public string? Error { get; set; } + + /// + /// + /// + public bool UseAcking { get; set; } } diff --git a/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt b/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..437066a1e801 100644 --- a/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/Http.Connections.Common/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Http.Connections.NegotiationResponse.UseAcking.get -> bool +Microsoft.AspNetCore.Http.Connections.NegotiationResponse.UseAcking.set -> void diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index a74f17d6aad7..8ec21012979f 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -7,10 +7,12 @@ using System.Security.Claims; using System.Security.Principal; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Timeouts; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -28,7 +30,8 @@ internal sealed partial class HttpConnectionContext : ConnectionContext, IHttpTransportFeature, IConnectionInherentKeepAliveFeature, IConnectionLifetimeFeature, - IConnectionLifetimeNotificationFeature + IConnectionLifetimeNotificationFeature, + IReconnectFeature { private readonly HttpConnectionDispatcherOptions _options; @@ -46,6 +49,7 @@ internal sealed partial class HttpConnectionContext : ConnectionContext, private CancellationTokenSource? _sendCts; private bool _activeSend; private long _startedSendTime; + private readonly bool _useAcks; private readonly object _sendingLock = new object(); internal CancellationToken SendingToken { get; private set; } @@ -57,7 +61,8 @@ internal sealed partial class HttpConnectionContext : ConnectionContext, /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. /// The caller is expected to set the and pipes manually. /// - public HttpConnectionContext(string connectionId, string connectionToken, ILogger logger, MetricsContext metricsContext, IDuplexPipe transport, IDuplexPipe application, HttpConnectionDispatcherOptions options) + public HttpConnectionContext(string connectionId, string connectionToken, ILogger logger, MetricsContext metricsContext, + IDuplexPipe transport, IDuplexPipe application, HttpConnectionDispatcherOptions options, bool useAcks) { Transport = transport; _applicationStream = new PipeWriterStream(application.Output); @@ -89,14 +94,22 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge Features.Set(this); Features.Set(this); + if (useAcks) + { + Features.Set(this); + } + _connectionClosedTokenSource = new CancellationTokenSource(); ConnectionClosed = _connectionClosedTokenSource.Token; _connectionCloseRequested = new CancellationTokenSource(); ConnectionClosedRequested = _connectionCloseRequested.Token; AuthenticationExpiration = DateTimeOffset.MaxValue; + _useAcks = useAcks; } + public bool UseAcks => _useAcks; + public CancellationTokenSource? Cancellation { get; set; } public HttpTransportType TransportType { get; set; } @@ -113,7 +126,7 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge internal bool IsAuthenticationExpirationEnabled => _options.CloseOnAuthenticationExpiration; - public Task? TransportTask { get; set; } + public Task? TransportTask { get; set; } public Task PreviousPollTask { get; set; } = Task.CompletedTask; @@ -189,6 +202,8 @@ public IDuplexPipe Application public CancellationToken ConnectionClosedRequested { get; set; } + public Action NotifyOnReconnect { get; set; } = () => { }; + public override void Abort() { ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts!).Cancel(), _connectionClosedTokenSource); @@ -384,6 +399,7 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl internal bool TryActivatePersistentConnection( ConnectionDelegate connectionDelegate, IHttpTransport transport, + Task currentRequestTask, HttpContext context, ILogger dispatcherLogger) { @@ -393,12 +409,16 @@ internal bool TryActivatePersistentConnection( { Status = HttpConnectionStatus.Active; + PreviousPollTask = currentRequestTask; + // Call into the end point passing the connection - ApplicationTask = ExecuteApplication(connectionDelegate); + ApplicationTask ??= ExecuteApplication(connectionDelegate); // Start the transport TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); + context.Features.Get()?.DisableTimeout(); + return true; } else @@ -440,7 +460,12 @@ public bool TryActivateLongPollingConnection( // On the first poll, we flush the response immediately to mark the poll as "initialized" so future // requests can be made safely - TransportTask = nonClonedContext.Response.Body.FlushAsync(); + TransportTask = Func(); + async Task Func() + { + await nonClonedContext.Response.Body.FlushAsync(); + return false; + }; } else { @@ -520,6 +545,14 @@ internal async Task CancelPreviousPoll(HttpContext context) // Cancel the previous request cts?.Cancel(); + // TODO: remove transport check once other transports support acks + if (UseAcks && TransportType == HttpTransportType.WebSockets) + { + // Break transport send loop in case it's still waiting on reading from the application + Application.Input.CancelPendingRead(); + UpdateConnectionPair(); + } + try { // Wait for the previous request to drain @@ -620,6 +653,23 @@ public void RequestClose() ThreadPool.UnsafeQueueUserWorkItem(static cts => ((CancellationTokenSource)cts!).Cancel(), _connectionCloseRequested); } + private void UpdateConnectionPair() + { + var prevPipe = Application.Input; + var input = new Pipe(_options.TransportPipeOptions); + + // Add new pipe for reading from and writing to transport from app code + var transportToApplication = new DuplexPipe(Transport.Input, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, Application.Output); + + Application = applicationToTransport; + Transport = transportToApplication; + + // Close previous pipe with specific error that application code can catch to know a restart is occurring + prevPipe.Complete(new ConnectionResetException("")); + Features.GetRequiredFeature().NotifyOnReconnect.Invoke(); + } + private static partial class Log { [LoggerMessage(1, LogLevel.Trace, "Disposing connection {TransportConnectionId}.", EventName = "DisposingConnection")] diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index de9fc73186df..de6cc55921a0 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -153,67 +153,74 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti // We only need to provide the Input channel since writing to the application is handled through /send. var sse = new ServerSentEventsServerTransport(connection.Application.Input, connection.ConnectionId, connection, _loggerFactory); - await DoPersistentConnection(connectionDelegate, sse, context, connection); - } - else if (context.WebSockets.IsWebSocketRequest) - { - // Connection can be established lazily - var connection = await GetOrCreateConnectionAsync(context, options); - if (connection == null) - { - // No such connection, GetOrCreateConnection already set the response status code - return; - } - - if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.WebSockets, supportedTransports, logScope)) + if (connection.TryActivatePersistentConnection(connectionDelegate, sse, Task.CompletedTask, context, _logger)) { - // Bad connection state. It's already set the response status code. - return; + await DoPersistentConnection(connection); } - - Log.EstablishedConnection(_logger); - - // Allow the reads to be canceled - connection.Cancellation = new CancellationTokenSource(); - - var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); - - await DoPersistentConnection(connectionDelegate, ws, context, connection); } else { - // GET /{path} maps to long polling + // GET /{path} maps to long polling or WebSockets - AddNoCacheHeaders(context.Response); + HttpConnectionContext? connection; + var transport = HttpTransportType.LongPolling; + if (context.WebSockets.IsWebSocketRequest) + { + transport = HttpTransportType.WebSockets; + connection = await GetOrCreateConnectionAsync(context, options); + } + else + { + AddNoCacheHeaders(context.Response); + // Connection must already exist + connection = await GetConnectionAsync(context); + } - // Connection must already exist - var connection = await GetConnectionAsync(context); if (connection == null) { // No such connection, GetConnection already set the response status code return; } - if (!await EnsureConnectionStateAsync(connection, context, HttpTransportType.LongPolling, supportedTransports, logScope)) + if (!await EnsureConnectionStateAsync(connection, context, transport, supportedTransports, logScope)) { // Bad connection state. It's already set the response status code. return; } - if (!await connection.CancelPreviousPoll(context)) + if (connection.TransportType != HttpTransportType.WebSockets || connection.UseAcks) { - // Connection closed. It's already set the response status code. - return; + if (!await connection.CancelPreviousPoll(context)) + { + // Connection closed. It's already set the response status code. + return; + } } // Create a new Tcs every poll to keep track of the poll finishing, so we can properly wait on previous polls var currentRequestTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - if (!connection.TryActivateLongPollingConnection( - connectionDelegate, context, options.LongPolling.PollTimeout, - currentRequestTcs.Task, _loggerFactory, _logger)) + switch (transport) { - return; + case HttpTransportType.None: + break; + case HttpTransportType.WebSockets: + var ws = new WebSocketsServerTransport(options.WebSockets, connection.Application, connection, _loggerFactory); + if (!connection.TryActivatePersistentConnection(connectionDelegate, ws, currentRequestTcs.Task, context, _logger)) + { + return; + } + break; + case HttpTransportType.LongPolling: + if (!connection.TryActivateLongPollingConnection( + connectionDelegate, context, options.LongPolling.PollTimeout, + currentRequestTcs.Task, _loggerFactory, _logger)) + { + return; + } + break; + default: + break; } context.Features.Get()?.DisableTimeout(); @@ -244,8 +251,15 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } else { - // Only allow repoll if we aren't removing the connection. - connection.MarkInactive(); + if (transport != HttpTransportType.LongPolling) + { + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false, HttpConnectionStopStatus.NormalClosure); + } + else + { + // Only allow repoll if we aren't removing the connection. + connection.MarkInactive(); + } } } else if (resultTask.IsFaulted || resultTask.IsCanceled) @@ -258,8 +272,18 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } else { - // Only allow repoll if we aren't removing the connection. - connection.MarkInactive(); + // If false then the transport was ungracefully closed, this can mean a temporary network disconnection + // We'll mark the connection as inactive and allow the connection to reconnect if that's the case. + // TODO: If acks aren't enabled we can close the connection immediately (not LongPolling) + if (await connection.TransportTask!) + { + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); + } + else + { + // Only allow repoll if we aren't removing the connection. + connection.MarkInactive(); + } } } finally @@ -271,19 +295,12 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti } } - private async Task DoPersistentConnection(ConnectionDelegate connectionDelegate, - IHttpTransport transport, - HttpContext context, - HttpConnectionContext connection) + private async Task DoPersistentConnection(HttpConnectionContext connection) { - if (connection.TryActivatePersistentConnection(connectionDelegate, transport, context, _logger)) - { - context.Features.Get()?.DisableTimeout(); - // Wait for any of them to end - await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!); + // Wait for any of them to end + await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!); - await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); - } + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true, HttpConnectionStopStatus.NormalClosure); } private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope) @@ -317,11 +334,18 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche Log.NegotiateProtocolVersionMismatch(_logger, 0); } + var useAck = false; + if (context.Request.Query.TryGetValue("UseAck", out var useAckValue)) + { + var useAckStringValue = useAckValue.ToString(); + bool.TryParse(useAckStringValue, out useAck); + } + // Establish the connection HttpConnectionContext? connection = null; if (error == null) { - connection = CreateConnection(options, clientProtocolVersion); + connection = CreateConnection(options, clientProtocolVersion, useAck); } // Set the Connection ID on the logging scope so that logs from now on will have the @@ -334,7 +358,7 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche try { // Get the bytes for the connection id - WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error); + WriteNegotiatePayload(writer, connection?.ConnectionId, connection?.ConnectionToken, context, options, clientProtocolVersion, error, useAck); Log.NegotiationRequest(_logger); @@ -349,7 +373,7 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche } private static void WriteNegotiatePayload(IBufferWriter writer, string? connectionId, string? connectionToken, HttpContext context, HttpConnectionDispatcherOptions options, - int clientProtocolVersion, string? error) + int clientProtocolVersion, string? error, bool useAck) { var response = new NegotiationResponse(); @@ -364,6 +388,7 @@ private static void WriteNegotiatePayload(IBufferWriter writer, string? co response.ConnectionId = connectionId; response.ConnectionToken = connectionToken; response.AvailableTransports = new List(); + response.UseAcking = useAck; if ((options.Transports & HttpTransportType.WebSockets) != 0 && ServerHasWebSockets(context.Features)) { @@ -745,9 +770,9 @@ private static void CloneHttpContext(HttpContext context, HttpConnectionContext return connection; } - private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0) + private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int clientProtocolVersion = 0, bool useAck = false) { - return _manager.CreateConnection(options, clientProtocolVersion); + return _manager.CreateConnection(options, clientProtocolVersion, useAck); } private static void AddNoCacheHeaders(HttpResponse response) diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 3f1cce2edb6b..c970679841ce 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -4,13 +4,13 @@ using System.Collections.Concurrent; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.IO.Pipelines; using System.Net.WebSockets; using System.Security.Cryptography; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using static System.IO.Pipelines.DuplexPipe; namespace Microsoft.AspNetCore.Http.Connections.Internal; @@ -67,7 +67,7 @@ internal HttpConnectionContext CreateConnection() /// Creates a connection without Pipes setup to allow saving allocations until Pipes are needed. /// /// - internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int negotiateVersion = 0) + internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options, int negotiateVersion = 0, bool useAck = false) { string connectionToken; var id = MakeNewConnectionId(); @@ -90,8 +90,8 @@ internal HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions HttpConnectionsEventSource.Log.ConnectionStart(id); _metrics.ConnectionStart(metricsContext); - var pair = DuplexPipe.CreateConnectionPair(options.TransportPipeOptions, options.AppPipeOptions); - var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger, metricsContext, pair.Application, pair.Transport, options); + var pair = CreateConnectionPair(options.TransportPipeOptions, options.AppPipeOptions); + var connection = new HttpConnectionContext(id, connectionToken, _connectionLogger, metricsContext, pair.Application, pair.Transport, options, useAck); _connections.TryAdd(connectionToken, (connection, startTimestamp)); diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs index af9ab19f4ef4..557c42c550e5 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/IHttpTransport.cs @@ -11,5 +11,5 @@ internal interface IHttpTransport /// /// /// A that completes when the transport has finished processing - Task ProcessRequestAsync(HttpContext context, CancellationToken token); + Task ProcessRequestAsync(HttpContext context, CancellationToken token); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs index e3e360328848..34251849d8d1 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs @@ -29,7 +29,7 @@ public LongPollingServerTransport(CancellationToken timeoutToken, PipeReader app _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.LongPollingTransport"); } - public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { try { @@ -43,7 +43,7 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok Log.LongPolling204(_logger); context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status204NoContent; - return; + return false; } // We're intentionally not checking cancellation here because we need to drain messages we've got so far, @@ -109,6 +109,7 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok context.Response.StatusCode = StatusCodes.Status500InternalServerError; throw; } + return false; } private static partial class Log diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs index 72c3c816b423..bf4f9d85d97b 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs @@ -28,7 +28,7 @@ public ServerSentEventsServerTransport(PipeReader application, string connection _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.ServerSentEventsTransport"); } - public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken) { context.Response.ContentType = "text/event-stream"; context.Response.Headers.CacheControl = "no-cache,no-store"; @@ -81,6 +81,8 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken can { // Closed connection } + + return true; } private static partial class Log diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs index 026285075b7d..d085cf8b0f84 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.Log.cs @@ -51,5 +51,8 @@ private static partial class Log [LoggerMessage(15, LogLevel.Debug, "Closing webSocket failed.", EventName = "ClosingWebSocketFailed")] public static partial void ClosingWebSocketFailed(ILogger logger, Exception ex); + + [LoggerMessage(16, LogLevel.Debug, "Send loop errored.", EventName = "SendErrored")] + public static partial void SendErrored(ILogger logger, Exception exception); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index 3d473aa095b9..77892ab942dc 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -17,6 +17,9 @@ internal sealed partial class WebSocketsServerTransport : IHttpTransport private readonly HttpConnectionContext _connection; private volatile bool _aborted; + // Used to determine if the close was graceful or a network issue + private bool _gracefulClose; + public WebSocketsServerTransport(WebSocketOptions options, IDuplexPipe application, HttpConnectionContext connection, ILoggerFactory loggerFactory) { ArgumentNullException.ThrowIfNull(options); @@ -31,7 +34,7 @@ public WebSocketsServerTransport(WebSocketOptions options, IDuplexPipe applicati _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.WebSocketsTransport"); } - public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) + public async Task ProcessRequestAsync(HttpContext context, CancellationToken token) { Debug.Assert(context.WebSockets.IsWebSocketRequest, "Not a websocket request"); @@ -50,13 +53,16 @@ public async Task ProcessRequestAsync(HttpContext context, CancellationToken tok Log.SocketClosed(_logger); } } + + return _gracefulClose; } public async Task ProcessSocketAsync(WebSocket socket) { - // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + var ignoreFirstCancel = false; + var receiving = StartReceiving(socket); - var sending = StartSending(socket); + var sending = StartSending(socket, ignoreFirstCancel); // Wait for send or receive to complete var trigger = await Task.WhenAny(receiving, sending); @@ -135,6 +141,7 @@ private async Task StartReceiving(WebSocket socket) if (result.MessageType == WebSocketMessageType.Close) { + _gracefulClose = true; return; } @@ -145,6 +152,7 @@ private async Task StartReceiving(WebSocket socket) // Need to check again for netcoreapp3.0 and later because a close can happen between a 0-byte read and the actual read if (receiveResult.MessageType == WebSocketMessageType.Close) { + _gracefulClose = true; return; } @@ -175,17 +183,21 @@ private async Task StartReceiving(WebSocket socket) { if (!_aborted && !token.IsCancellationRequested) { + _gracefulClose = true; _application.Output.Complete(ex); } } finally { - // We're done writing - _application.Output.Complete(); + if (_gracefulClose) + { + // We're done writing + _application.Output.Complete(); + } } } - private async Task StartSending(WebSocket socket) + private async Task StartSending(WebSocket socket, bool ignoreFirstCancel) { Exception? error = null; @@ -200,11 +212,13 @@ private async Task StartSending(WebSocket socket) try { - if (result.IsCanceled) + if (result.IsCanceled && !ignoreFirstCancel) { break; } + ignoreFirstCancel = false; + if (!buffer.IsEmpty) { try @@ -225,6 +239,12 @@ private async Task StartSending(WebSocket socket) break; } } + catch (OperationCanceledException ex) when (ex.CancellationToken == _connection.SendingToken) + { + _gracefulClose = true; + // TODO: probably log + break; + } catch (Exception ex) { if (!_aborted) @@ -266,7 +286,14 @@ private async Task StartSending(WebSocket socket) } } - _application.Input.Complete(); + if (_gracefulClose) + { + _application.Input.Complete(error); + } + else if (error is not null) + { + Log.SendErrored(_logger, error); + } } } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 9d16f32f4a6b..40b25d09dfc0 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1392,7 +1392,7 @@ public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTrans var options = new HttpConnectionDispatcherOptions(); var request1 = dispatcher.ExecuteAsync(context1, options, app); - await dispatcher.ExecuteAsync(context2, options, app); + await dispatcher.ExecuteAsync(context2, options, app).DefaultTimeout(); Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index ff067e567f4a..a08e47a0b79f 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -95,6 +95,7 @@ public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates { throw new Exception("Transport failed"); } + return false; }); } @@ -102,7 +103,7 @@ public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates { // If the transport is faulted then we want to make sure the transport task only completes after // the application completes - connection.TransportTask = Task.FromException(new Exception("Application failed")); + connection.TransportTask = Task.FromException(new Exception("Application failed")); connection.ApplicationTask = Task.Run(async () => { // Wait for the application to end @@ -113,7 +114,7 @@ public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates else { connection.ApplicationTask = Task.CompletedTask; - connection.TransportTask = Task.CompletedTask; + connection.TransportTask = Task.FromResult(true); } try @@ -271,6 +272,7 @@ public async Task CloseConnectionsEndsAllPendingConnections() { connection.Application.Input.AdvanceTo(result.Buffer.End); } + return true; }); connectionManager.CloseConnections(); @@ -286,7 +288,7 @@ public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; connection.TransportTask = tcs.Task; @@ -296,7 +298,7 @@ public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() Assert.False(firstTask.IsCompleted); Assert.False(secondTask.IsCompleted); - tcs.TrySetResult(); + tcs.TrySetResult(true); await Task.WhenAll(firstTask, secondTask).DefaultTimeout(); } @@ -309,7 +311,7 @@ public async Task DisposingConnectionMultipleGetsExceptionFromTransportOrApp() { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; connection.TransportTask = tcs.Task; @@ -336,7 +338,7 @@ public async Task DisposingConnectionMultipleGetsCancellation() { var connectionManager = CreateConnectionManager(LoggerFactory); var connection = connectionManager.CreateConnection(); - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); connection.ApplicationTask = tcs.Task; connection.TransportTask = tcs.Task; diff --git a/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs b/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs index a3283f1b8a51..48163419726b 100644 --- a/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs +++ b/src/SignalR/common/Http.Connections/test/WebSocketsTests.cs @@ -110,7 +110,8 @@ public async Task WebSocketTransportSetsMessageTypeBasedOnTransferFormatFeature( private HttpConnectionContext CreateHttpConnectionContext(DuplexPipe.DuplexPipePair pair, string loggerName = null) { - return new HttpConnectionContext("foo", connectionToken: null, LoggerFactory.CreateLogger(loggerName ?? nameof(HttpConnectionContext)), metricsContext: default, pair.Transport, pair.Application, new()); + return new HttpConnectionContext("foo", connectionToken: null, LoggerFactory.CreateLogger(loggerName ?? nameof(HttpConnectionContext)), + metricsContext: default, pair.Transport, pair.Application, new(), useAcks: false); } [Fact] diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 3feaefc13d10..01322cd1ecab 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -43,6 +43,8 @@ public sealed class JsonHubProtocol : IHubProtocol private static readonly JsonEncodedText ArgumentsPropertyNameBytes = JsonEncodedText.Encode(ArgumentsPropertyName); private const string HeadersPropertyName = "headers"; private static readonly JsonEncodedText HeadersPropertyNameBytes = JsonEncodedText.Encode(HeadersPropertyName); + private const string SequenceIdPropertyName = "sequenceId"; + private static readonly JsonEncodedText SequenceIdPropertyNameBytes = JsonEncodedText.Encode(SequenceIdPropertyName); private const string ProtocolName = "json"; private const int ProtocolVersion = 1; @@ -139,6 +141,7 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) Dictionary? headers = null; var completed = false; var allowReconnect = false; + long? sequenceId = null; var reader = new Utf8JsonReader(input, isFinalBlock: true, state: default); @@ -325,6 +328,10 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) reader.CheckRead(); headers = ReadHeaders(ref reader); } + else if (reader.ValueTextEquals(SequenceIdPropertyNameBytes.EncodedUtf8Bytes)) + { + sequenceId = reader.ReadAsInt64(SequenceIdPropertyName); + } else { reader.CheckRead(); @@ -452,6 +459,10 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) return PingMessage.Instance; case HubProtocolConstants.CloseMessageType: return BindCloseMessage(error, allowReconnect); + case HubProtocolConstants.AckMessageType: + return BindAckMessage(sequenceId); + case HubProtocolConstants.SequenceMessageType: + return BindSequenceMessage(sequenceId); case null: throw new InvalidDataException($"Missing required property '{TypePropertyName}'."); default: @@ -544,6 +555,14 @@ private void WriteMessageCore(HubMessage message, IBufferWriter stream) WriteMessageType(writer, HubProtocolConstants.CloseMessageType); WriteCloseMessage(m, writer); break; + case AckMessage m: + WriteMessageType(writer, HubProtocolConstants.AckMessageType); + WriteAckMessage(m, writer); + break; + case SequenceMessage m: + WriteMessageType(writer, HubProtocolConstants.SequenceMessageType); + WriteSequenceMessage(m, writer); + break; default: throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); } @@ -651,6 +670,16 @@ private static void WriteCloseMessage(CloseMessage message, Utf8JsonWriter write } } + private static void WriteAckMessage(AckMessage message, Utf8JsonWriter writer) + { + writer.WriteNumber(SequenceIdPropertyName, message.SequenceId); + } + + private static void WriteSequenceMessage(SequenceMessage message, Utf8JsonWriter writer) + { + writer.WriteNumber(SequenceIdPropertyName, message.SequenceId); + } + private void WriteArguments(object?[] arguments, Utf8JsonWriter writer) { writer.WriteStartArray(ArgumentsPropertyNameBytes); @@ -741,7 +770,8 @@ private static HubMessage BindStreamItemMessage(string invocationId, object? ite return new StreamItemMessage(invocationId, item); } - private static HubMessage BindStreamInvocationMessage(string? invocationId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) + private static HubMessage BindStreamInvocationMessage(string? invocationId, string target, + object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(invocationId)) { @@ -763,7 +793,8 @@ private static HubMessage BindStreamInvocationMessage(string? invocationId, stri return new StreamInvocationMessage(invocationId, target, arguments, streamIds); } - private static HubMessage BindInvocationMessage(string? invocationId, string target, object?[]? arguments, bool hasArguments, string[]? streamIds) + private static HubMessage BindInvocationMessage(string? invocationId, string target, + object?[]? arguments, bool hasArguments, string[]? streamIds) { if (string.IsNullOrEmpty(target)) { @@ -853,6 +884,26 @@ private static CloseMessage BindCloseMessage(string? error, bool allowReconnect) return new CloseMessage(error, allowReconnect); } + private static AckMessage BindAckMessage(long? sequenceId) + { + if (sequenceId is null) + { + throw new InvalidDataException("Missing 'sequenceId' in Ack message."); + } + + return new AckMessage(sequenceId.Value); + } + + private static SequenceMessage BindSequenceMessage(long? sequenceId) + { + if (sequenceId is null) + { + throw new InvalidDataException("Missing 'sequenceId' in Sequence message."); + } + + return new SequenceMessage(sequenceId.Value); + } + private static HubMessage ApplyHeaders(HubMessage message, Dictionary? headers) { if (headers != null && message is HubInvocationMessage invocationMessage) diff --git a/src/SignalR/common/Shared/MessageBuffer.cs b/src/SignalR/common/Shared/MessageBuffer.cs new file mode 100644 index 000000000000..01c3b2850349 --- /dev/null +++ b/src/SignalR/common/Shared/MessageBuffer.cs @@ -0,0 +1,512 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal sealed class MessageBuffer : IDisposable +{ + private static readonly TaskCompletionSource _completedTCS = new TaskCompletionSource(); + + private readonly ConnectionContext _connection; + private readonly IHubProtocol _protocol; + private readonly AckMessage _ackMessage = new(0); + private readonly SequenceMessage _sequenceMessage = new(0); + private readonly Channel _waitForAck = Channel.CreateBounded(new BoundedChannelOptions(1) { FullMode = BoundedChannelFullMode.DropOldest }); + private readonly int _bufferLimit = 100 * 1000; + +#if NET8_0_OR_GREATER + private readonly PeriodicTimer _timer = new(TimeSpan.FromSeconds(1)); +#else + private readonly TimerAwaitable _timer = new(TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1)); +#endif + private readonly SemaphoreSlim _writeLock = new(1, 1); + + private long _totalMessageCount; + private bool _waitForSequenceMessage; + + // Message IDs start at 1 and always increment by 1 + private long _currentReceivingSequenceId = 1; + private long _latestReceivedSequenceId = long.MinValue; + private long _lastAckedId = long.MinValue; + + private TaskCompletionSource _resend = _completedTCS; + + private object Lock => _buffer; + + private LinkedBuffer _buffer; + private int _bufferedByteCount; + + static MessageBuffer() + { + _completedTCS.SetResult(new()); + } + + // TODO: pass in limits + public MessageBuffer(ConnectionContext connection, IHubProtocol protocol) + { + // TODO: pool + _buffer = new LinkedBuffer(); + + _connection = connection; + _protocol = protocol; + +#if !NET8_0_OR_GREATER + _timer.Start(); +#endif + _ = RunTimer(); + } + + private async Task RunTimer() + { + using (_timer) + { +#if NET8_0_OR_GREATER + while (await _timer.WaitForNextTickAsync().ConfigureAwait(false)) +#else + while (await _timer) +#endif + { + if (_lastAckedId < _latestReceivedSequenceId) + { + // TODO: consider a minimum time between sending these? + + var sequenceId = _latestReceivedSequenceId; + + _ackMessage.SequenceId = sequenceId; + + await _writeLock.WaitAsync().ConfigureAwait(false); + try + { + _protocol.WriteMessage(_ackMessage, _connection.Transport.Output); + await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); + _lastAckedId = sequenceId; + } + finally + { + _writeLock.Release(); + } + } + } + } + } + + /// + /// Calling code is assumed to not call this method in parallel. Currently HubConnection and HubConnectionContext respect that. + /// + // TODO: WriteAsync(HubMessage) overload, so we don't allocate SerializedHubMessage for messages that aren't going to be buffered + public async ValueTask WriteAsync(SerializedHubMessage hubMessage, CancellationToken cancellationToken) + { + // TODO: Backpressure based on message count and total message size + if (_bufferedByteCount > _bufferLimit) + { + // primitive backpressure if buffer is full + while (await _waitForAck.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (_waitForAck.Reader.TryRead(out var count) && count < _bufferLimit) + { + break; + } + } + } + + // Avoid condition where last Ack position is the position we're currently writing into the buffer + // If we wrote messages around the entire buffer before another Ack arrived we would end up reading the Ack position and writing over a buffered message + _waitForAck.Reader.TryRead(out _); + + // TODO: We could consider buffering messages until they hit backpressure in the case when the connection is down + await _resend.Task.ConfigureAwait(false); + + var waitForResend = false; + + await _writeLock.WaitAsync(cancellationToken: default).ConfigureAwait(false); + try + { + if (hubMessage.Message is HubInvocationMessage invocationMessage) + { + _totalMessageCount++; + } + else + { + // Non-ackable message, don't add to buffer + return await _connection.Transport.Output.WriteAsync(hubMessage.GetSerializedMessage(_protocol), cancellationToken).ConfigureAwait(false); + } + + var messageBytes = hubMessage.GetSerializedMessage(_protocol); + lock (Lock) + { + _bufferedByteCount += messageBytes.Length; + _buffer.AddMessage(hubMessage, _totalMessageCount); + } + + return await _connection.Transport.Output.WriteAsync(messageBytes, cancellationToken).ConfigureAwait(false); + } + // TODO: figure out what exception to use + catch (ConnectionResetException) + { + waitForResend = true; + } + finally + { + _writeLock.Release(); + } + + if (waitForResend) + { + var oldTcs = Interlocked.Exchange(ref _resend, new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously)); + if (!oldTcs.Task.IsCompleted) + { + return await oldTcs.Task.ConfigureAwait(false); + } + return await _resend.Task.ConfigureAwait(false); + } + + throw new NotImplementedException("shouldn't reach here"); + } + + public void Ack(AckMessage ackMessage) + { + // TODO: what if ackMessage.SequenceId is larger than last sent message? + + var newCount = -1; + + lock (Lock) + { + var item = _buffer.RemoveMessages(ackMessage.SequenceId, _protocol); + _buffer = item.Item1; + _bufferedByteCount -= item.Item2; + + newCount = _bufferedByteCount; + } + + // Release potential backpressure + if (newCount >= 0) + { + _waitForAck.Writer.TryWrite(newCount); + } + } + + internal bool ShouldProcessMessage(HubMessage message) + { + // TODO: if we're expecting a sequence message but get here should we error or ignore or maybe even continue to process them? + if (_waitForSequenceMessage) + { + if (message is SequenceMessage) + { + _waitForSequenceMessage = false; + return true; + } + else + { + // ignore messages received while waiting for sequence message + return false; + } + } + + // Only care about messages implementing HubInvocationMessage currently (e.g. ignore ping, close, ack, sequence) + // Could expand in the future, but should probably rev the ack version if changes are made + if (message is not HubInvocationMessage) + { + return true; + } + + var currentId = _currentReceivingSequenceId; + _currentReceivingSequenceId++; + if (currentId <= _latestReceivedSequenceId) + { + // Ignore, this is a duplicate message + return false; + } + _latestReceivedSequenceId = currentId; + + return true; + } + + internal void ResetSequence(SequenceMessage sequenceMessage) + { + // TODO: is a sequence message expected right now? + + if (sequenceMessage.SequenceId > _currentReceivingSequenceId) + { + throw new InvalidOperationException("Sequence ID greater than amount of messages we've received."); + } + _currentReceivingSequenceId = sequenceMessage.SequenceId; + } + + internal void Resend() + { + _waitForSequenceMessage = true; + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var oldTcs = Interlocked.Exchange(ref _resend, tcs); + // WriteAsync can also try to swap the TCS, we need to check if it's completed to know if it was swapped or not + if (!oldTcs.Task.IsCompleted) + { + // Swap back to the TCS created by WriteAsync since it's waiting on the result of that task + Interlocked.Exchange(ref _resend, oldTcs); + tcs = oldTcs; + } + + _ = DoResendAsync(tcs); + } + + private async Task DoResendAsync(TaskCompletionSource tcs) + { + FlushResult finalResult = new(); + await _writeLock.WaitAsync().ConfigureAwait(false); + try + { + _sequenceMessage.SequenceId = _totalMessageCount + 1; + + var isFirst = true; + foreach (var item in _buffer.GetMessages()) + { + if (item.SequenceId > 0) + { + if (isFirst) + { + _sequenceMessage.SequenceId = item.SequenceId; + _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); + isFirst = false; + } + finalResult = await _connection.Transport.Output.WriteAsync(item.HubMessage!.GetSerializedMessage(_protocol)).ConfigureAwait(false); + } + } + + if (isFirst) + { + _protocol.WriteMessage(_sequenceMessage, _connection.Transport.Output); + finalResult = await _connection.Transport.Output.FlushAsync().ConfigureAwait(false); + } + } + catch (Exception ex) + { + tcs.SetException(ex); + } + finally + { + _writeLock.Release(); + tcs.TrySetResult(finalResult); + } + } + + public void Dispose() + { + ((IDisposable)_timer).Dispose(); + } + + // Linked list of SerializedHubMessage arrays, sort of like ReadOnlySequence + private sealed class LinkedBuffer + { + private const int BufferLength = 10; + + private int _currentIndex = -1; + private int _ackedIndex = -1; + private long _startingSequenceId = long.MinValue; + private LinkedBuffer? _next; + + private readonly SerializedHubMessage?[] _messages = new SerializedHubMessage?[BufferLength]; + + public void AddMessage(SerializedHubMessage hubMessage, long sequenceId) + { + if (_startingSequenceId < 0) + { + Debug.Assert(_currentIndex == -1); + _startingSequenceId = sequenceId; + } + + if (_currentIndex < BufferLength - 1) + { + Debug.Assert(_startingSequenceId + _currentIndex + 1 == sequenceId); + + _currentIndex++; + _messages[_currentIndex] = hubMessage; + } + else if (_next is null) + { + _next = new LinkedBuffer(); + _next.AddMessage(hubMessage, sequenceId); + } + else + { + // TODO: Should we avoid this path by keeping a tail pointer? + // Debug.Assert(false); + + var linkedBuffer = _next; + while (linkedBuffer._next is not null) + { + linkedBuffer = linkedBuffer._next; + } + + // TODO: verify no stack overflow potential + linkedBuffer.AddMessage(hubMessage, sequenceId); + } + } + + public (LinkedBuffer, int returnCredit) RemoveMessages(long sequenceId, IHubProtocol protocol) + { + return RemoveMessagesCore(this, sequenceId, protocol); + } + + private static (LinkedBuffer, int returnCredit) RemoveMessagesCore(LinkedBuffer linkedBuffer, long sequenceId, IHubProtocol protocol) + { + var returnCredit = 0; + while (linkedBuffer._startingSequenceId <= sequenceId) + { + var numElements = (int)Math.Min(BufferLength, Math.Max(1, sequenceId - (linkedBuffer._startingSequenceId - 1))); + Debug.Assert(numElements > 0 && numElements < BufferLength + 1); + + for (var i = 0; i < numElements; i++) + { + returnCredit += linkedBuffer._messages[i]?.GetSerializedMessage(protocol).Length ?? 0; + linkedBuffer._messages[i] = null; + } + + linkedBuffer._ackedIndex = numElements - 1; + + if (numElements == BufferLength) + { + if (linkedBuffer._next is null) + { + linkedBuffer.Reset(shouldPool: false); + return (linkedBuffer, returnCredit); + } + else + { + var tmp = linkedBuffer; + linkedBuffer = linkedBuffer._next; + tmp.Reset(shouldPool: true); + } + } + else + { + return (linkedBuffer, returnCredit); + } + } + + return (linkedBuffer, returnCredit); + } + + private void Reset(bool shouldPool) + { + _startingSequenceId = long.MinValue; + _currentIndex = -1; + _ackedIndex = -1; + _next = null; + + Array.Clear(_messages, 0, BufferLength); + + // TODO: Add back to pool + if (shouldPool) + { + } + } + + public IEnumerable<(SerializedHubMessage? HubMessage, long SequenceId)> GetMessages() + { + return new Enumerable(this); + } + + private struct Enumerable : IEnumerable<(SerializedHubMessage?, long)> + { + private readonly LinkedBuffer _linkedBuffer; + + public Enumerable(LinkedBuffer linkedBuffer) + { + _linkedBuffer = linkedBuffer; + } + + public IEnumerator<(SerializedHubMessage?, long)> GetEnumerator() + { + return new Enumerator(_linkedBuffer); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } + } + + private struct Enumerator : IEnumerator<(SerializedHubMessage?, long)> + { + private LinkedBuffer? _linkedBuffer; + private int _index; + + public Enumerator(LinkedBuffer linkedBuffer) + { + _linkedBuffer = linkedBuffer; + } + + public (SerializedHubMessage?, long) Current + { + get + { + if (_linkedBuffer is null) + { + return (null, long.MinValue); + } + + var index = _index - 1; + var firstMessageIndex = _linkedBuffer._ackedIndex + 1; + if (firstMessageIndex + index < BufferLength) + { + return (_linkedBuffer._messages[firstMessageIndex + index], _linkedBuffer._startingSequenceId + firstMessageIndex + index); + } + + return (null, long.MinValue); + } + } + + object IEnumerator.Current => throw new NotImplementedException(); + + public void Dispose() + { + _linkedBuffer = null; + } + + public bool MoveNext() + { + if (_linkedBuffer is null) + { + return false; + } + + var firstMessageIndex = _linkedBuffer._ackedIndex + 1; + if (firstMessageIndex + _index >= BufferLength) + { + _linkedBuffer = _linkedBuffer._next; + _index = 1; + } + else + { + if (_linkedBuffer._messages[firstMessageIndex + _index] is null) + { + _linkedBuffer = null; + } + else + { + _index++; + } + } + + return _linkedBuffer is not null; + } + + public void Reset() + { + throw new NotImplementedException(); + } + } + } +} diff --git a/src/SignalR/common/Shared/SystemTextJsonExtensions.cs b/src/SignalR/common/Shared/SystemTextJsonExtensions.cs index c28ba9e16398..30f1c9adc6ab 100644 --- a/src/SignalR/common/Shared/SystemTextJsonExtensions.cs +++ b/src/SignalR/common/Shared/SystemTextJsonExtensions.cs @@ -97,4 +97,21 @@ public static string ReadAsString(this ref Utf8JsonReader reader, string propert return reader.GetInt32(); } + + public static long? ReadAsInt64(this ref Utf8JsonReader reader, string propertyName) + { + reader.Read(); + + if (reader.TokenType == JsonTokenType.Null) + { + return null; + } + + if (reader.TokenType != JsonTokenType.Number) + { + throw new InvalidDataException($"Expected '{propertyName}' to be of type {JsonTokenType.Number}."); + } + + return reader.GetInt64(); + } } diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs new file mode 100644 index 000000000000..33ef91881fa9 --- /dev/null +++ b/src/SignalR/common/SignalR.Common/src/Protocol/AckMessage.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Protocol; + +/// +/// Represents the ID being acknowledged so we can stop buffering older messages. +/// +public sealed class AckMessage : HubMessage +{ + /// + /// + /// + /// + public AckMessage(long sequenceId) + { + SequenceId = sequenceId; + } + + /// + /// + /// + public long SequenceId { get; set; } +} + +/// +/// Represents the restart of the sequence of messages being sent. is the starting ID of messages being sent, which might be duplicate messages. +/// +public sealed class SequenceMessage : HubMessage +{ + /// + /// + /// + /// + public SequenceMessage(long sequenceId) + { + SequenceId = sequenceId; + } + + /// + /// + /// + public long SequenceId { get; set; } +} diff --git a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs index eb1e3914ac17..c5e67987ae92 100644 --- a/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs +++ b/src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs @@ -42,4 +42,14 @@ public static class HubProtocolConstants /// Represents the close message type. /// public const int CloseMessageType = 7; + + /// + /// + /// + public const int AckMessageType = 8; + + /// + /// + /// + public const int SequenceMessageType = 9; } diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt index 7dc5c58110bf..29b26f2e7839 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/net462/PublicAPI.Unshipped.txt @@ -1 +1,11 @@ #nullable enable +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.AckMessageType = 8 -> int +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.SequenceMessageType = 9 -> int +Microsoft.AspNetCore.SignalR.Protocol.AckMessage +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(long sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceMessage(long sequenceId) -> void diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt index 7dc5c58110bf..29b26f2e7839 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -1 +1,11 @@ #nullable enable +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.AckMessageType = 8 -> int +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.SequenceMessageType = 9 -> int +Microsoft.AspNetCore.SignalR.Protocol.AckMessage +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(long sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceMessage(long sequenceId) -> void diff --git a/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt b/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt index 7dc5c58110bf..29b26f2e7839 100644 --- a/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/SignalR/common/SignalR.Common/src/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt @@ -1 +1,11 @@ #nullable enable +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.AckMessageType = 8 -> int +const Microsoft.AspNetCore.SignalR.Protocol.HubProtocolConstants.SequenceMessageType = 9 -> int +Microsoft.AspNetCore.SignalR.Protocol.AckMessage +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.AckMessage(long sequenceId) -> void +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.AckMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.get -> long +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceId.set -> void +Microsoft.AspNetCore.SignalR.Protocol.SequenceMessage.SequenceMessage(long sequenceId) -> void diff --git a/src/SignalR/docs/specs/TransportProtocols.md b/src/SignalR/docs/specs/TransportProtocols.md index a4c10f4eadfa..e2fb3d28cb17 100644 --- a/src/SignalR/docs/specs/TransportProtocols.md +++ b/src/SignalR/docs/specs/TransportProtocols.md @@ -20,12 +20,20 @@ Throughout this document, the term `[endpoint-base]` is used to refer to the rou The `POST [endpoint-base]/negotiate` request is used to establish a connection between the client and the server. +*negotiateVersion:* + In the POST request the client sends a query string parameter with the key "negotiateVersion" and the value as the negotiate protocol version it would like to use. If the query string is omitted, the server treats the version as zero. The server will include a "negotiateVersion" property in the json response that says which version it will be using. The version is chosen as described below: * If the servers minimum supported protocol version is greater than the version requested by the client it will send an error response and close the connection * If the server supports the request version it will respond with the requested version * If the requested version is greater than the servers largest supported version the server will respond with its largest supported version The client may close the connection if the "negotiateVersion" in the response is not acceptable. +*useAck:* + +In the POST request the client may include a query string parameter with the key "useAck" and the value of "true". If this is included the server will decide if it supports/allows the [ack protocol](#todo) described below, and return "useAck": "true" as a json property in the negotiate response if it will use the ack protocol. If true, the client can reconnect using the same transport and reuse the connectionToken/connectionId. The server may still reject the reconnect if it takes too long or for any reason it chooses. If false, the client must not reuse connectionToken/connectionId. If the "useAck" property is missing from the negotiate response this also implies false, so the ack protocol should not be used. + +----------- + The content type of the response is `application/json` and is a JSON payload containing properties to assist the client in establishing a persistent connection. Extra JSON properties that the client does not know about should be ignored. This allows for future additions without breaking older clients. ### Version 1 diff --git a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs index bdbf8443d752..d0efd3d4489c 100644 --- a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -35,6 +35,7 @@ public void GlobalSetup() new HubContext(hubLifetimeManager), enableDetailedErrors: false, disableImplicitFromServiceParameters: true, + useAcks: false, new Logger>(NullLoggerFactory.Instance), hubFilters: null, hubLifetimeManager); diff --git a/src/SignalR/samples/ClientSample/HubSample.cs b/src/SignalR/samples/ClientSample/HubSample.cs index f77b9c3a9a8d..506a520b2238 100644 --- a/src/SignalR/samples/ClientSample/HubSample.cs +++ b/src/SignalR/samples/ClientSample/HubSample.cs @@ -99,7 +99,7 @@ public static async Task ExecuteAsync(string baseUrl) try { - await connection.InvokeAsync("Send", line); + await connection.InvokeAsync("Send", "C#", line); } catch when (closedTokenSource.IsCancellationRequested) { diff --git a/src/SignalR/samples/SignalRSamples/Program.cs b/src/SignalR/samples/SignalRSamples/Program.cs index 3675e5c32b37..c610b486315f 100644 --- a/src/SignalR/samples/SignalRSamples/Program.cs +++ b/src/SignalR/samples/SignalRSamples/Program.cs @@ -25,12 +25,12 @@ public static Task Main(string[] args) { factory.AddConfiguration(c.Configuration.GetSection("Logging")); factory.AddConsole(); - //factory.SetMinimumLevel(LogLevel.Trace); + factory.SetMinimumLevel(LogLevel.Debug); }) .UseKestrel(options => { // Default port - options.ListenAnyIP(0); + options.ListenAnyIP(5000); // Hub bound to TCP end point //options.Listen(IPAddress.Any, 9001, builder => diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index a2a9f24429ef..d9dcc8beca2e 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -8,6 +8,7 @@ using System.IO.Pipelines; using System.Security.Claims; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Internal; @@ -36,6 +37,7 @@ public partial class HubConnectionContext private readonly CancellationTokenRegistration _closedRegistration; private readonly CancellationTokenRegistration? _closedRequestedRegistration; + private MessageBuffer? _messageBuffer; private StreamTracker? _streamTracker; private long _lastSendTick; private ReadOnlyMemory _cachedPingMessage; @@ -48,6 +50,10 @@ public partial class HubConnectionContext private TimeSpan _receivedMessageElapsed; private long _receivedMessageTick; private ClaimsPrincipal? _user; + private bool _useAcks; + + [MemberNotNullWhen(true, nameof(_messageBuffer))] + internal bool UsingAcks() => _useAcks; /// /// Initializes a new instance of the class. @@ -254,11 +260,18 @@ private ValueTask WriteCore(HubMessage message, CancellationToken c { try { - // We know that we are only writing this message to one receiver, so we can - // write it without caching. - Protocol.WriteMessage(message, _connectionContext.Transport.Output); + if (UsingAcks()) + { + return _messageBuffer.WriteAsync(new SerializedHubMessage(message), cancellationToken); + } + else + { + // We know that we are only writing this message to one receiver, so we can + // write it without caching. + Protocol.WriteMessage(message, _connectionContext.Transport.Output); - return _connectionContext.Transport.Output.FlushAsync(cancellationToken); + return _connectionContext.Transport.Output.FlushAsync(cancellationToken); + } } catch (Exception ex) { @@ -275,10 +288,18 @@ private ValueTask WriteCore(SerializedHubMessage message, Cancellat { try { - // Grab a preserialized buffer for this protocol. - var buffer = message.GetSerializedMessage(Protocol); + if (UsingAcks()) + { + Debug.Assert(_messageBuffer is not null); + return _messageBuffer.WriteAsync(message, cancellationToken); + } + else + { + // Grab a potentially pre-serialized buffer for this protocol. + var buffer = message.GetSerializedMessage(Protocol); - return _connectionContext.Transport.Output.WriteAsync(buffer, cancellationToken); + return _connectionContext.Transport.Output.WriteAsync(buffer, cancellationToken); + } } catch (Exception ex) { @@ -550,6 +571,13 @@ await WriteHandshakeResponseAsync(new HandshakeResponseMessage( Log.HandshakeComplete(_logger, Protocol.Name); await WriteHandshakeResponseAsync(HandshakeResponseMessage.Empty); + + if (_connectionContext.Features.Get() is IReconnectFeature feature) + { + _useAcks = true; + _messageBuffer = new MessageBuffer(_connectionContext, Protocol); + feature.NotifyOnReconnect = _messageBuffer.Resend; + } return true; } else if (overLength) @@ -725,10 +753,36 @@ internal void StopClientTimeout() internal void Cleanup() { + _messageBuffer?.Dispose(); _closedRegistration.Dispose(); _closedRequestedRegistration?.Dispose(); // Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist _streamTracker?.CompleteAll(new OperationCanceledException("The underlying connection was closed.")); } + + internal void Ack(AckMessage ackMessage) + { + if (UsingAcks()) + { + _messageBuffer.Ack(ackMessage); + } + } + + internal bool ShouldProcessMessage(HubMessage message) + { + if (UsingAcks()) + { + return _messageBuffer.ShouldProcessMessage(message); + } + return true; + } + + internal void ResetSequence(SequenceMessage sequenceMessage) + { + if (UsingAcks()) + { + _messageBuffer.ResetSequence(sequenceMessage); + } + } } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index ab3d0f5bbd7b..21e9061897e8 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -94,6 +94,8 @@ IServiceScopeFactory serviceScopeFactory new HubContext(lifetimeManager), _enableDetailedErrors, disableImplicitFromServiceParameters, + // TODO + useAcks: true, new Logger>(loggerFactory), hubFilters, lifetimeManager); diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 3458f6760a9f..a632f59a7538 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -28,15 +28,17 @@ internal sealed partial class DefaultHubDispatcher : HubDispatcher w private readonly Func? _onConnectedMiddleware; private readonly Func? _onDisconnectedMiddleware; private readonly HubLifetimeManager _hubLifetimeManager; + private readonly bool _useAcks; public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext hubContext, bool enableDetailedErrors, - bool disableImplicitFromServiceParameters, ILogger> logger, List? hubFilters, HubLifetimeManager lifetimeManager) + bool disableImplicitFromServiceParameters, bool useAcks, ILogger> logger, List? hubFilters, HubLifetimeManager lifetimeManager) { _serviceScopeFactory = serviceScopeFactory; _hubContext = hubContext; _enableDetailedErrors = enableDetailedErrors; _logger = logger; _hubLifetimeManager = lifetimeManager; + _useAcks = useAcks; DiscoverHubMethods(disableImplicitFromServiceParameters); var count = hubFilters?.Count ?? 0; @@ -130,6 +132,12 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe // With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running. + if (!connection.ShouldProcessMessage(hubMessage)) + { + Log.DroppingMessage(_logger, ((HubInvocationMessage)hubMessage).GetType().Name, ((HubInvocationMessage)hubMessage).InvocationId); + return Task.CompletedTask; + } + switch (hubMessage) { case InvocationBindingFailureMessage bindingFailureMessage: @@ -186,6 +194,16 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe } break; + case AckMessage ackMessage: + Log.ReceivedAckMessage(_logger, ackMessage.SequenceId); + connection.Ack(ackMessage); + break; + + case SequenceMessage sequenceMessage: + Log.ReceivedSequenceMessage(_logger, sequenceMessage.SequenceId); + connection.ResetSequence(sequenceMessage); + break; + // Other kind of message we weren't expecting default: Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName!); @@ -374,7 +392,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, // No InvocationId - Send Async, no response expected if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { - // Invoke Async, one reponse expected + // Invoke Async, one response expected await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); } } @@ -555,8 +573,7 @@ private async Task StreamAsync(string invocationId, HubConnectionContext connect } } - private static async Task SendInvocationError(string? invocationId, - HubConnectionContext connection, string errorMessage) + private static async Task SendInvocationError(string? invocationId, HubConnectionContext connection, string errorMessage) { if (string.IsNullOrEmpty(invocationId)) { diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs index 24d6ee8b7fca..dce901a2d952 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcherLog.cs @@ -108,4 +108,13 @@ public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessa [LoggerMessage(25, LogLevel.Error, "Invocation ID {InvocationId}: Failed while sending stream items from hub method {HubMethod}.", EventName = "FailedStreaming")] public static partial void FailedStreaming(ILogger logger, string invocationId, string hubMethod, Exception exception); + + [LoggerMessage(26, LogLevel.Trace, "Dropping {MessageType} with ID '{InvocationId}'.", EventName = "DroppingMessage")] + public static partial void DroppingMessage(ILogger logger, string messageType, string? invocationId); + + [LoggerMessage(27, LogLevel.Trace, "Received AckMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedAckMessage")] + public static partial void ReceivedAckMessage(ILogger logger, long sequenceId); + + [LoggerMessage(28, LogLevel.Trace, "Received SequenceMessage with Sequence ID '{SequenceId}'.", EventName = "ReceivedSequenceMessage")] + public static partial void ReceivedSequenceMessage(ILogger logger, long sequenceId); } diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index db4c7d7e81df..b4552ea0c9f6 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -17,6 +17,7 @@ + diff --git a/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs b/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs index 3daaaa4e4477..3f954b0e37a2 100644 --- a/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs +++ b/src/SignalR/server/SignalR/test/DefaultTransportFactoryTests.cs @@ -53,7 +53,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(HttpTran { var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, - transportFactory.CreateTransport(AllTransportTypes)); + transportFactory.CreateTransport(AllTransportTypes, useAck: true)); } [Theory] @@ -66,7 +66,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(Http var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( - () => transportFactory.CreateTransport(~requestedTransport)); + () => transportFactory.CreateTransport(~requestedTransport, useAck: true)); Assert.Equal("No requested transports available on the server.", ex.Message); } @@ -77,7 +77,7 @@ public void DefaultTransportFactoryCreatesWebSocketsTransportIfAvailable() { Assert.IsType( new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null) - .CreateTransport(AllTransportTypes)); + .CreateTransport(AllTransportTypes, useAck: true)); } [Theory] @@ -90,7 +90,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable_Win7(Htt { var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, - transportFactory.CreateTransport(AllTransportTypes)); + transportFactory.CreateTransport(AllTransportTypes, useAck: true)); } } @@ -103,7 +103,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport_Win7 var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( - () => transportFactory.CreateTransport(AllTransportTypes)); + () => transportFactory.CreateTransport(AllTransportTypes, useAck: true)); Assert.Equal("No requested transports available on the server.", ex.Message); } diff --git a/src/SignalR/server/SignalR/test/EndToEndTests.cs b/src/SignalR/server/SignalR/test/EndToEndTests.cs index a6e85b802510..e280f20af9ef 100644 --- a/src/SignalR/server/SignalR/test/EndToEndTests.cs +++ b/src/SignalR/server/SignalR/test/EndToEndTests.cs @@ -341,7 +341,7 @@ async Task ReceiveMessage() { logger.LogInformation("Receiving message"); // Big timeout here because it can take a while to receive all the bytes - var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length).DefaultTimeout(TimeSpan.FromMinutes(2)); + var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length).DefaultTimeout(); Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); logger.LogInformation("Completed receive"); } @@ -685,7 +685,7 @@ private class TestTransportFactory : ITransportFactory { private ITransport _transport; - public ITransport CreateTransport(HttpTransportType availableServerTransports) + public ITransport CreateTransport(HttpTransportType availableServerTransports, bool useAck) { if (_transport == null) { diff --git a/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs b/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs new file mode 100644 index 000000000000..823470c29d81 --- /dev/null +++ b/src/SignalR/server/SignalR/test/Internal/MessageBufferTests.cs @@ -0,0 +1,323 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO.Pipelines; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.Testing; + +namespace Microsoft.AspNetCore.SignalR.Tests.Internal; + +public class MessageBufferTests +{ + [Fact] + public async Task CanWriteNonBufferedMessagesWithoutBlocking() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + for (var i = 0; i < 100; i++) + { + await messageBuffer.WriteAsync(new SerializedHubMessage(PingMessage.Instance), default).DefaultTimeout(); + } + + var count = 0; + while (count < 100) + { + var res = await pipes.Application.Input.ReadAsync().DefaultTimeout(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + count++; + } + } + + [Fact] + public async Task WriteBlocksOnAckWhenBufferFull() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions(pauseWriterThreshold: 200000, resumeWriterThreshold: 100000)); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + await messageBuffer.WriteAsync(new SerializedHubMessage(new InvocationMessage("t", new object[] { new byte[100000] })), default); + + var writeTask = messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + Assert.False(writeTask.IsCompleted); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + // Write not unblocked by read, only unblocked after ack received + Assert.False(writeTask.IsCompleted); + + messageBuffer.Ack(new AckMessage(1)); + await writeTask.DefaultTimeout(); + + res = await pipes.Application.Input.ReadAsync().DefaultTimeout(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + } + + [Fact] + public async Task UnAckedMessageResentOnReconnect() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + // Any message except SequenceMessage will be ignored until a SequenceMessage is received + Assert.False(messageBuffer.ShouldProcessMessage(PingMessage.Instance)); + Assert.False(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); + Assert.True(messageBuffer.ShouldProcessMessage(new SequenceMessage(1))); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + messageBuffer.ResetSequence(new SequenceMessage(1)); + + Assert.True(messageBuffer.ShouldProcessMessage(PingMessage.Instance)); + Assert.True(messageBuffer.ShouldProcessMessage(CompletionMessage.WithResult("1", null))); + } + + [Fact] + public async Task AckedMessageNotResentOnReconnect() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("id", null)), default); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + messageBuffer.Ack(new AckMessage(1)); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(2, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + await messageBuffer.WriteAsync(new SerializedHubMessage(CompletionMessage.WithResult("1", null)), default); + + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + } + + [Fact] + public async Task ReceiveSequenceMessageWithLargerIDThanMessagesReceived() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + Assert.Throws(() => messageBuffer.ResetSequence(new SequenceMessage(2))); + } + + [Fact] + public async Task WriteManyMessagesAckSomeProperlyBuffers() + { + var protocol = new JsonHubProtocol(); + var connection = new TestConnectionContext(); + var pipes = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + connection.Transport = pipes.Transport; + using var messageBuffer = new MessageBuffer(connection, protocol); + + for (var i = 0; i < 1000; i++) + { + await messageBuffer.WriteAsync(new SerializedHubMessage(new StreamItemMessage("1", null)), default); + } + + var ackNum = Random.Shared.Next(0, 1000); + messageBuffer.Ack(new AckMessage(ackNum)); + + DuplexPipe.UpdateConnectionPair(ref pipes, connection); + messageBuffer.Resend(); + + var res = await pipes.Application.Input.ReadAsync(); + + var buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out var message)); + var seqMessage = Assert.IsType(message); + Assert.Equal(ackNum + 1, seqMessage.SequenceId); + + pipes.Application.Input.AdvanceTo(buffer.Start); + + for (var i = 0; i < 1000 - ackNum; i++) + { + res = await pipes.Application.Input.ReadAsync(); + + buffer = res.Buffer; + Assert.True(protocol.TryParseMessage(ref buffer, new TestBinder(), out message)); + Assert.IsType(message); + + pipes.Application.Input.AdvanceTo(buffer.Start); + } + } +} + +internal sealed class TestConnectionContext : ConnectionContext +{ + public override string ConnectionId { get; set; } + public override IFeatureCollection Features { get; } = new FeatureCollection(); + public override IDictionary Items { get; set; } + public override IDuplexPipe Transport { get; set; } +} + +internal sealed class DuplexPipe : IDuplexPipe +{ + public DuplexPipe(PipeReader reader, PipeWriter writer) + { + Input = reader; + Output = writer; + } + + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); + + var transportToApplication = new DuplexPipe(output.Reader, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); + + return new DuplexPipePair(applicationToTransport, transportToApplication); + } + + // This class exists to work around issues with value tuple on .NET Framework + public struct DuplexPipePair + { + public IDuplexPipe Transport { get; set; } + public IDuplexPipe Application { get; set; } + + public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + Transport = transport; + Application = application; + } + } + + public static void UpdateConnectionPair(ref DuplexPipePair duplexPipePair, ConnectionContext connection) + { + var prevPipe = duplexPipePair.Application.Input; + var input = new Pipe(); + + // Add new pipe for reading from and writing to transport from app code + var transportToApplication = new DuplexPipe(duplexPipePair.Transport.Input, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, duplexPipePair.Application.Output); + + duplexPipePair.Application = applicationToTransport; + duplexPipePair.Transport = transportToApplication; + + connection.Transport = duplexPipePair.Transport; + + // Close previous pipe with specific error that application code can catch to know a restart is occurring + prevPipe.Complete(new ConnectionResetException("")); + } +} + +internal sealed class TestBinder : IInvocationBinder +{ + public IReadOnlyList GetParameterTypes(string methodName) + { + var list = new List + { + typeof(object) + }; + return list; + } + + public Type GetReturnType(string invocationId) + { + return typeof(object); + } + + public Type GetStreamItemType(string streamId) + { + return typeof(object); + } +}