diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs
index 898c4aff32..e8c4e46553 100644
--- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs
@@ -24,6 +24,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
public class HubConnection
{
+ public static readonly TimeSpan DefaultServerTimeout = TimeSpan.FromSeconds(30); // Server ping rate is 15 sec, this is 2 times that.
+
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
private readonly IConnection _connection;
@@ -38,9 +40,17 @@ public class HubConnection
private int _nextId = 0;
private volatile bool _startCalled;
+ private Timer _timeoutTimer;
+ private bool _needKeepAlive;
public Task Closed { get; }
+ ///
+ /// Gets or sets the server timeout interval for the connection. Changes to this value
+ /// will not be applied until the Keep Alive timer is next reset.
+ ///
+ public TimeSpan ServerTimeout { get; set; } = DefaultServerTimeout;
+
public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
{
if (connection == null)
@@ -64,6 +74,9 @@ public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFacto
Shutdown(task.Exception);
return task;
}).Unwrap();
+
+ // Create the timer for timeout, but disabled by default (we enable it when started).
+ _timeoutTimer = new Timer(state => ((HubConnection)state).TimeoutElapsed(), this, Timeout.Infinite, Timeout.Infinite);
}
public async Task StartAsync()
@@ -78,6 +91,20 @@ public async Task StartAsync()
}
}
+ private void TimeoutElapsed()
+ {
+ _connection.AbortAsync(new TimeoutException($"Server timeout ({ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server."));
+ }
+
+ private void ResetTimeoutTimer()
+ {
+ if (_needKeepAlive)
+ {
+ _logger.ResettingKeepAliveTimer();
+ _timeoutTimer.Change(ServerTimeout, Timeout.InfiniteTimeSpan);
+ }
+ }
+
private async Task StartAsyncCore()
{
var transferModeFeature = _connection.Features.Get();
@@ -94,6 +121,7 @@ private async Task StartAsyncCore()
transferModeFeature.TransferMode = requestedTransferMode;
await _connection.StartAsync();
+ _needKeepAlive = _connection.Features.Get() == null;
var actualTransferMode = transferModeFeature.TransferMode;
_protocolReaderWriter = new HubProtocolReaderWriter(_protocol, GetDataEncoder(requestedTransferMode, actualTransferMode));
@@ -105,6 +133,8 @@ private async Task StartAsyncCore()
NegotiationProtocol.WriteMessage(new NegotiationMessage(_protocol.Name), memoryStream);
await _connection.SendAsync(memoryStream.ToArray(), _connectionActive.Token);
}
+
+ ResetTimeoutTimer();
}
private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, TransferMode actualTransferMode)
@@ -125,6 +155,7 @@ private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, Transfer
private async Task DisposeAsyncCore()
{
+ _timeoutTimer.Dispose();
await _connection.DisposeAsync();
await Closed;
}
@@ -298,6 +329,7 @@ private async Task SendAsyncCore(string methodName, object[] args, CancellationT
private async Task OnDataReceivedAsync(byte[] data)
{
+ ResetTimeoutTimer();
if (_protocolReaderWriter.ReadMessages(data, _binder, out var messages))
{
foreach (var message in messages)
diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs
index c30c4e17f3..1f7edd7c29 100644
--- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs
@@ -85,6 +85,9 @@ internal static class SignalRClientLoggerExtensions
private static readonly Action _preparingStreamingInvocation =
LoggerMessage.Define(LogLevel.Trace, new EventId(24, nameof(PreparingStreamingInvocation)), "Preparing streaming invocation '{invocationId}' of '{target}', with return type '{returnType}' and {argumentCount} argument(s).");
+ private static readonly Action _resettingKeepAliveTimer =
+ LoggerMessage.Define(LogLevel.Trace, new EventId(25, nameof(ResettingKeepAliveTimer)), "Resetting keep-alive timer, received a message from the server.");
+
// Category: Streaming and NonStreaming
private static readonly Action _invocationCreated =
LoggerMessage.Define(LogLevel.Trace, new EventId(0, nameof(InvocationCreated)), "Invocation {invocationId} created.");
@@ -282,7 +285,12 @@ public static void StreamItemOnNonStreamInvocation(this ILogger logger, string i
public static void ErrorInvokingClientSideMethod(this ILogger logger, string methodName, Exception exception)
{
- _errorInvokingClientSideMethod(logger, methodName, exception);
+ _errorInvokingClientSideMethod(logger, methodName, exception);
+ }
+
+ public static void ResettingKeepAliveTimer(this ILogger logger)
+ {
+ _resettingKeepAliveTimer(logger, null);
}
}
}
diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj b/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj
index e05b0b19f7..9e3d8f1784 100644
--- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj
+++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Microsoft.AspNetCore.SignalR.Client.Core.csproj
@@ -3,6 +3,7 @@
Client for ASP.NET Core SignalR
netstandard2.0
+ Microsoft.AspNetCore.SignalR.Client
diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs
index 09179c276f..b8efdfa960 100644
--- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolConstants.cs
@@ -1,6 +1,6 @@
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
- internal static class HubProtocolConstants
+ public static class HubProtocolConstants
{
public const int InvocationMessageType = 1;
public const int StreamItemMessageType = 2;
diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs
index be3f5f26fe..7152c33f9b 100644
--- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@@ -13,6 +13,7 @@ public interface IConnection
Task StartAsync();
Task SendAsync(byte[] data, CancellationToken cancellationToken);
Task DisposeAsync();
+ Task AbortAsync(Exception ex);
IDisposable OnReceived(Func callback, object state);
diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs
index e312bad199..0773017e92 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs
@@ -95,7 +95,7 @@ public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactor
_logger = _loggerFactory.CreateLogger();
_httpOptions = httpOptions;
_httpClient = _httpOptions?.HttpMessageHandler == null ? new HttpClient() : new HttpClient(_httpOptions?.HttpMessageHandler);
- _httpClient.Timeout = HttpClientTimeout;
+ _httpClient.Timeout = HttpClientTimeout;
_transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory));
}
@@ -303,7 +303,7 @@ private async Task StartTransport(Uri connectUrl)
// Start the transport, giving it one end of the pipeline
try
{
- await _transport.StartAsync(connectUrl, applicationSide, requestedTransferMode: GetTransferMode(), connectionId: _connectionId);
+ await _transport.StartAsync(connectUrl, applicationSide, GetTransferMode(), _connectionId, this);
// actual transfer mode can differ from the one that was requested so set it on the feature
Debug.Assert(_transport.Mode.HasValue, "transfer mode not set after transport started");
@@ -435,11 +435,25 @@ private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToke
}
}
+ public async Task AbortAsync(Exception ex) => await DisposeAsyncCore(ex ?? new InvalidOperationException("Connection aborted")).ForceAsync();
+
public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();
- private async Task DisposeAsyncCore()
+ private async Task DisposeAsyncCore(Exception ex = null)
{
- _logger.StoppingClient(_connectionId);
+ if (ex != null)
+ {
+ _logger.AbortingClient(_connectionId, ex);
+
+ // Immediately fault the close task. When the transport shuts down,
+ // it will trigger the close task to be completed, so we want it to be
+ // marked faulted before that happens
+ _closedTcs.TrySetException(ex);
+ }
+ else
+ {
+ _logger.StoppingClient(_connectionId);
+ }
if (Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected) == ConnectionState.Initial)
{
@@ -472,6 +486,7 @@ private async Task DisposeAsyncCore()
await _receiveLoopTask;
}
+ // If we haven't already done so, trigger the Closed task.
_closedTcs.TrySetResult(null);
_httpClient?.Dispose();
}
diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs
index 784400db85..620e0e1a1d 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs
@@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
public interface ITransport
{
- Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId);
+ Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection);
Task StopAsync();
TransferMode? Mode { get; }
}
diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs
index 1ac5a0a0d4..a68cfb9659 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs
@@ -153,6 +153,9 @@ internal static class SocketClientLoggerExtensions
private static readonly Action _exceptionThrownFromCallback =
LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromCallback)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{callback}' callback");
+ private static readonly Action _abortingClient =
+ LoggerMessage.Define(LogLevel.Error, new EventId(20, nameof(AbortingClient)), "{time}: Connection Id {connectionId}: Aborting client.");
+
public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode)
{
@@ -506,6 +509,14 @@ public static void SendingMessage(this ILogger logger, string connectionId)
}
}
+ public static void AbortingClient(this ILogger logger, string connectionId, Exception ex)
+ {
+ if (logger.IsEnabled(LogLevel.Error))
+ {
+ _abortingClient(logger, DateTime.Now, connectionId, ex);
+ }
+ }
+
public static void StoppingClient(this ILogger logger, string connectionId)
{
if (logger.IsEnabled(LogLevel.Information))
diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs
index b2c8dec3ae..5135a8a56c 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs
@@ -2,14 +2,14 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
-using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Threading;
-using System.Threading.Tasks;
using System.Threading.Channels;
+using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.AspNetCore.Sockets.Client.Internal;
+using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@@ -42,13 +42,15 @@ public LongPollingTransport(HttpClient httpClient, HttpOptions httpOptions, ILog
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger();
}
- public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId)
+ public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection)
{
if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text)
{
throw new ArgumentException("Invalid transfer mode.", nameof(requestedTransferMode));
}
+ connection.Features.Set(new ConnectionInherentKeepAliveFeature(_httpClient.Timeout));
+
_application = application;
Mode = requestedTransferMode;
_connectionId = connectionId;
diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs
index 2f288562a1..c0c5b929cd 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs
@@ -49,7 +49,7 @@ public ServerSentEventsTransport(HttpClient httpClient, HttpOptions httpOptions,
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger();
}
- public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId)
+ public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection)
{
if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text)
{
diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs
index fdaa2b64c0..f68c2d8b37 100644
--- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs
@@ -55,7 +55,7 @@ public WebSocketsTransport(HttpOptions httpOptions, ILoggerFactory loggerFactory
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger();
}
- public async Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId)
+ public async Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId, IConnection connection)
{
if (url == null)
{
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs
index 86d975fb89..baa342476e 100644
--- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs
@@ -155,7 +155,7 @@ public async Task CanStopStartingConnection()
releaseDisposeTcs.SetResult(null);
await disposeTask.OrTimeout();
- transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()), Times.Never);
+ transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never);
}
[Fact]
@@ -263,8 +263,8 @@ public async Task ReceivedCallbackNotRaisedAfterConnectionIsDisposed()
var mockTransport = new Mock();
Channel channel = null;
- mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()))
- .Returns, TransferMode, string>((url, c, transferMode, connectionId) =>
+ mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()))
+ .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) =>
{
channel = c;
return Task.CompletedTask;
@@ -311,8 +311,8 @@ public async Task EventsAreNotRunningOnMainLoop()
var mockTransport = new Mock();
Channel channel = null;
- mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()))
- .Returns, TransferMode, string>((url, c, transferMode, connectionId) =>
+ mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()))
+ .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) =>
{
channel = c;
return Task.CompletedTask;
@@ -368,8 +368,8 @@ public async Task EventQueueTimeout()
var mockTransport = new Mock();
Channel channel = null;
- mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()))
- .Returns, TransferMode, string>((url, c, transferMode, connectionId) =>
+ mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()))
+ .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) =>
{
channel = c;
return Task.CompletedTask;
@@ -413,8 +413,8 @@ public async Task EventQueueTimeoutWithException()
var mockTransport = new Mock();
Channel channel = null;
- mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()))
- .Returns, TransferMode, string>((url, c, transferMode, connectionId) =>
+ mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()))
+ .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) =>
{
channel = c;
return Task.CompletedTask;
@@ -925,8 +925,8 @@ public async Task CanStartConnectionWithoutSettingTransferModeFeature()
var mockTransport = new Mock();
Channel channel = null;
- mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()))
- .Returns, TransferMode, string>((url, c, transferMode, connectionId) =>
+ mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny(), It.IsAny()))
+ .Returns, TransferMode, string, IConnection>((url, c, transferMode, connectionId, _) =>
{
channel = c;
return Task.CompletedTask;
@@ -947,7 +947,7 @@ public async Task CanStartConnectionWithoutSettingTransferModeFeature()
await connection.DisposeAsync().OrTimeout();
mockTransport.Verify(t => t.StartAsync(
- It.IsAny(), It.IsAny>(), TransferMode.Text, It.IsAny()), Times.Once);
+ It.IsAny(), It.IsAny>(), TransferMode.Text, It.IsAny(), It.IsAny()), Times.Once);
Assert.NotNull(transferModeFeature);
Assert.Equal(TransferMode.Binary, transferModeFeature.TransferMode);
}
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs
index 35cd3a1a6e..85784f347d 100644
--- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs
@@ -9,6 +9,7 @@
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets.Client;
+using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;
@@ -33,6 +34,7 @@ public async Task StartAsyncCallsConnectionStart()
public async Task DisposeAsyncCallsConnectionStart()
{
var connection = new Mock();
+ connection.Setup(m => m.Features).Returns(new FeatureCollection());
connection.Setup(m => m.StartAsync()).Verifiable();
var hubConnection = new HubConnection(connection.Object, Mock.Of(), null);
await hubConnection.DisposeAsync();
@@ -185,6 +187,19 @@ public async Task PendingInvocationsAreTerminatedWithExceptionWhenConnectionClos
await Assert.ThrowsAsync(async () => await invokeTask);
}
+ [Fact]
+ public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages()
+ {
+ var connection = new TestConnection();
+ var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory());
+
+ hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(100);
+
+ await hubConnection.StartAsync().OrTimeout();
+ var ex = await Assert.ThrowsAsync(async () => await hubConnection.Closed.OrTimeout());
+ Assert.Equal("Server timeout (100.00ms) elapsed without receiving a message from the server.", ex.Message);
+ }
+
// Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse
private class MockHubProtocol : IHubProtocol
{
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs
index 55c1452223..e4e5e43a15 100644
--- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs
@@ -9,6 +9,7 @@
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Client.Tests;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Internal;
@@ -43,7 +44,7 @@ public async Task LongPollingTransportStopsPollAndSendLoopsWhenTransportStopped(
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
transportActiveTask = longPollingTransport.Running;
@@ -79,7 +80,7 @@ public async Task LongPollingTransportStopsWhenPollReceives204()
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = ChannelConnection.Create(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
await longPollingTransport.Running.OrTimeout();
Assert.True(transportToConnection.Reader.Completion.IsCompleted);
@@ -132,7 +133,7 @@ public async Task LongPollingTransportResponseWithNoContentDoesNotStopPoll()
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
var data = await transportToConnection.Reader.ReadAllAsync().OrTimeout();
await longPollingTransport.Running.OrTimeout();
@@ -168,7 +169,7 @@ public async Task LongPollingTransportStopsWhenPollRequestFails()
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
var exception =
await Assert.ThrowsAsync(async () => await transportToConnection.Reader.Completion.OrTimeout());
@@ -204,7 +205,7 @@ public async Task LongPollingTransportStopsWhenSendRequestFails()
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
await connectionToTransport.Writer.WriteAsync(new SendMessage());
@@ -245,7 +246,7 @@ public async Task LongPollingTransportShutsDownWhenChannelIsClosed()
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
connectionToTransport.Writer.Complete();
@@ -296,7 +297,7 @@ public async Task LongPollingTransportDispatchesMessagesReceivedFromPoll()
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
// Start the transport
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
// Wait for the transport to finish
await longPollingTransport.Running.OrTimeout();
@@ -361,7 +362,7 @@ public async Task LongPollingTransportSendsAvailableMessagesWhenTheyArrive()
await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout();
// Start the transport
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
connectionToTransport.Writer.Complete();
@@ -404,7 +405,7 @@ public async Task LongPollingTransportSetsTransferMode(TransferMode transferMode
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
Assert.Null(longPollingTransport.Mode);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty, connection: new TestConnection());
Assert.Equal(transferMode, longPollingTransport.Mode);
}
finally
@@ -430,7 +431,7 @@ public async Task LongPollingTransportThrowsForInvalidTransferMode()
{
var longPollingTransport = new LongPollingTransport(httpClient);
var exception = await Assert.ThrowsAsync(() =>
- longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty));
+ longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection()));
Assert.Contains("Invalid transfer mode.", exception.Message);
Assert.Equal("requestedTransferMode", exception.ParamName);
@@ -468,7 +469,7 @@ public async Task LongPollingTransportRePollsIfRequestCancelled()
var connectionToTransport = Channel.CreateUnbounded();
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
- await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty);
+ await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty, connection: new TestConnection());
var completedTask = await Task.WhenAny(completionTcs.Task, longPollingTransport.Running).OrTimeout();
Assert.Equal(completionTcs.Task, completedTask);
diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs
index 961700585f..4a4e8f81a0 100644
--- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs
+++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs
@@ -55,7 +55,7 @@ public async Task CanStartStopSSETransport()
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
await sseTransport.StartAsync(
- new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout();
+ new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout();
await eventStreamTcs.Task.OrTimeout();
await sseTransport.StopAsync().OrTimeout();
@@ -108,7 +108,7 @@ public async Task SSETransportStopsSendAndReceiveLoopsWhenTransportStopped()
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
await sseTransport.StartAsync(
- new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout();
+ new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout();
transportActiveTask = sseTransport.Running;
Assert.False(transportActiveTask.IsCompleted);
@@ -156,7 +156,7 @@ public async Task SSETransportStopsWithErrorIfServerSendsIncompleteResults()
var transportToConnection = Channel.CreateUnbounded();
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
await sseTransport.StartAsync(
- new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout();
+ new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout();
var exception = await Assert.ThrowsAsync(() => sseTransport.Running.OrTimeout());
Assert.Equal("Incomplete message.", exception.Message);
@@ -202,7 +202,7 @@ public async Task SSETransportStopsWithErrorIfSendingMessageFails()
var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection);
await sseTransport.StartAsync(
- new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout();
+ new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty, connection: Mock.Of()).OrTimeout();
await eventStreamTcs.Task;
var sendTcs = new TaskCompletionSource