diff --git a/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp.cs b/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp.cs index 8f042c6dae18..aee4918239f3 100644 --- a/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp.cs +++ b/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp.cs @@ -137,6 +137,7 @@ public KestrelServerOptions() { } public Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerLimits Limits { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } } public Microsoft.AspNetCore.Server.Kestrel.KestrelConfigurationLoader Configure() { throw null; } public Microsoft.AspNetCore.Server.Kestrel.KestrelConfigurationLoader Configure(Microsoft.Extensions.Configuration.IConfiguration config) { throw null; } + public Microsoft.AspNetCore.Server.Kestrel.KestrelConfigurationLoader Configure(Microsoft.Extensions.Configuration.IConfiguration config, bool reloadOnChange) { throw null; } public void ConfigureEndpointDefaults(System.Action configureOptions) { } public void ConfigureHttpsDefaults(System.Action configureOptions) { } public void Listen(System.Net.EndPoint endPoint) { } diff --git a/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs b/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs index f4c1859b7fe6..fc425b1c9c3c 100644 --- a/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs +++ b/src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -10,8 +10,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { internal class AddressBindContext { - public ICollection Addresses { get; set; } - public List ListenOptions { get; set; } + public ServerAddressesFeature ServerAddressesFeature { get; set; } + public ICollection Addresses => ServerAddressesFeature.InternalCollection; + public KestrelServerOptions ServerOptions { get; set; } public ILogger Logger { get; set; } diff --git a/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs b/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs index c87b0653bdbd..42c74ee95f86 100644 --- a/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs +++ b/src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs @@ -19,30 +19,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { internal class AddressBinder { - public static async Task BindAsync(IServerAddressesFeature addresses, - KestrelServerOptions serverOptions, - ILogger logger, - Func createBinding) + public static async Task BindAsync(IEnumerable listenOptions, AddressBindContext context) { - var listenOptions = serverOptions.ListenOptions; var strategy = CreateStrategy( listenOptions.ToArray(), - addresses.Addresses.ToArray(), - addresses.PreferHostingUrls); - - var context = new AddressBindContext - { - Addresses = addresses.Addresses, - ListenOptions = listenOptions, - ServerOptions = serverOptions, - Logger = logger, - CreateBinding = createBinding - }; + context.Addresses.ToArray(), + context.ServerAddressesFeature.PreferHostingUrls); // reset options. The actual used options and addresses will be populated // by the address binding feature - listenOptions.Clear(); - addresses.Addresses.Clear(); + context.ServerOptions.OptionsInUse.Clear(); + context.Addresses.Clear(); await strategy.BindAsync(context).ConfigureAwait(false); } @@ -109,7 +96,7 @@ internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBind throw new IOException(CoreStrings.FormatEndpointAlreadyInUse(endpoint), ex); } - context.ListenOptions.Add(endpoint); + context.ServerOptions.OptionsInUse.Add(endpoint); } internal static ListenOptions ParseAddress(string address, out bool https) diff --git a/src/Servers/Kestrel/Core/src/Internal/ConfigSectionClone.cs b/src/Servers/Kestrel/Core/src/Internal/ConfigSectionClone.cs new file mode 100644 index 000000000000..230006d1ba67 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ConfigSectionClone.cs @@ -0,0 +1,60 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + internal class ConfigSectionClone + { + public ConfigSectionClone(IConfigurationSection configSection) + { + Value = configSection.Value; + + // GetChildren() should return an empty IEnumerable instead of null, but we guard against it since it's a public interface. + var children = configSection.GetChildren() ?? Enumerable.Empty(); + Children = children.ToDictionary(child => child.Key, child => new ConfigSectionClone(child)); + } + + public string Value { get; } + public Dictionary Children { get; } + + public override bool Equals(object? obj) + { + if (!(obj is ConfigSectionClone other)) + { + return false; + } + + if (Value != other.Value || Children.Count != other.Children.Count) + { + return false; + } + + foreach (var kvp in Children) + { + if (!other.Children.TryGetValue(kvp.Key, out var child)) + { + return false; + } + + if (kvp.Value != child) + { + return false; + } + } + + return true; + } + + public override int GetHashCode() => HashCode.Combine(Value, Children.Count); + + public static bool operator ==(ConfigSectionClone lhs, ConfigSectionClone rhs) => lhs is null ? rhs is null : lhs.Equals(rhs); + public static bool operator !=(ConfigSectionClone lhs, ConfigSectionClone rhs) => !(lhs == rhs); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs b/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs index 79dc84d5c1c1..f54f43ca7a64 100644 --- a/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs +++ b/src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs @@ -17,95 +17,50 @@ internal class ConfigurationReader private const string UrlKey = "Url"; private const string Latin1RequestHeadersKey = "Latin1RequestHeaders"; - private IConfiguration _configuration; - private IDictionary _certificates; - private IList _endpoints; - private EndpointDefaults _endpointDefaults; - private bool? _latin1RequestHeaders; + private readonly IConfiguration _configuration; public ConfigurationReader(IConfiguration configuration) { _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + Certificates = ReadCertificates(); + EndpointDefaults = ReadEndpointDefaults(); + Endpoints = ReadEndpoints(); + Latin1RequestHeaders = _configuration.GetValue(Latin1RequestHeadersKey); } - public IDictionary Certificates - { - get - { - if (_certificates == null) - { - ReadCertificates(); - } - - return _certificates; - } - } - - public EndpointDefaults EndpointDefaults - { - get - { - if (_endpointDefaults == null) - { - ReadEndpointDefaults(); - } - - return _endpointDefaults; - } - } - - public IEnumerable Endpoints - { - get - { - if (_endpoints == null) - { - ReadEndpoints(); - } - - return _endpoints; - } - } + public IDictionary Certificates { get; } + public EndpointDefaults EndpointDefaults { get; } + public IEnumerable Endpoints { get; } + public bool Latin1RequestHeaders { get; } - public bool Latin1RequestHeaders + private IDictionary ReadCertificates() { - get - { - if (_latin1RequestHeaders is null) - { - _latin1RequestHeaders = _configuration.GetValue(Latin1RequestHeadersKey); - } - - return _latin1RequestHeaders.Value; - } - } - - private void ReadCertificates() - { - _certificates = new Dictionary(0); + var certificates = new Dictionary(0); var certificatesConfig = _configuration.GetSection(CertificatesKey).GetChildren(); foreach (var certificateConfig in certificatesConfig) { - _certificates.Add(certificateConfig.Key, new CertificateConfig(certificateConfig)); + certificates.Add(certificateConfig.Key, new CertificateConfig(certificateConfig)); } + + return certificates; } // "EndpointDefaults": { // "Protocols": "Http1AndHttp2", // } - private void ReadEndpointDefaults() + private EndpointDefaults ReadEndpointDefaults() { var configSection = _configuration.GetSection(EndpointDefaultsKey); - _endpointDefaults = new EndpointDefaults + return new EndpointDefaults { Protocols = ParseProtocols(configSection[ProtocolsKey]) }; } - private void ReadEndpoints() + private IEnumerable ReadEndpoints() { - _endpoints = new List(); + var endpoints = new List(); var endpointsConfig = _configuration.GetSection(EndpointsKey).GetChildren(); foreach (var endpointConfig in endpointsConfig) @@ -133,8 +88,11 @@ private void ReadEndpoints() ConfigSection = endpointConfig, Certificate = new CertificateConfig(endpointConfig.GetSection(CertificateKey)), }; - _endpoints.Add(endpoint); + + endpoints.Add(endpoint); } + + return endpoints; } private static HttpProtocols? ParseProtocols(string protocols) @@ -154,7 +112,6 @@ private void ReadEndpoints() internal class EndpointDefaults { public HttpProtocols? Protocols { get; set; } - public IConfigurationSection ConfigSection { get; set; } } // "EndpointName": { @@ -167,11 +124,41 @@ internal class EndpointDefaults // } internal class EndpointConfig { + private IConfigurationSection _configSection; + private ConfigSectionClone _configSectionClone; + public string Name { get; set; } public string Url { get; set; } public HttpProtocols? Protocols { get; set; } - public IConfigurationSection ConfigSection { get; set; } public CertificateConfig Certificate { get; set; } + + // Compare config sections because it's accessible to app developers via an Action callback. + // We cannot rely entirely on comparing config sections for equality, because KestrelConfigurationLoader.Reload() sets + // EndpointConfig properties to their default values. If a default value changes, the properties would no longer be equal, + // but the config sections could still be equal. + public IConfigurationSection ConfigSection + { + get => _configSection; + set + { + _configSection = value; + // The IConfigrationSection will mutate, so we need to take a snapshot to compare against later and check for changes. + _configSectionClone = new ConfigSectionClone(value); + } + } + + public override bool Equals(object obj) => + obj is EndpointConfig other && + Name == other.Name && + Url == other.Url && + (Protocols ?? ListenOptions.DefaultHttpProtocols) == (other.Protocols ?? ListenOptions.DefaultHttpProtocols) && + Certificate == other.Certificate && + _configSectionClone == other._configSectionClone; + + public override int GetHashCode() => HashCode.Combine(Name, Url, Protocols ?? ListenOptions.DefaultHttpProtocols, Certificate, _configSectionClone); + + public static bool operator ==(EndpointConfig lhs, EndpointConfig rhs) => lhs is null ? rhs is null : lhs.Equals(rhs); + public static bool operator !=(EndpointConfig lhs, EndpointConfig rhs) => !(lhs == rhs); } // "CertificateName": { @@ -206,5 +193,19 @@ public CertificateConfig(IConfigurationSection configSection) public string Location { get; set; } public bool? AllowInvalid { get; set; } + + public override bool Equals(object obj) => + obj is CertificateConfig other && + Path == other.Path && + Password == other.Password && + Subject == other.Subject && + Store == other.Store && + Location == other.Location && + (AllowInvalid ?? false) == (other.AllowInvalid ?? false); + + public override int GetHashCode() => HashCode.Combine(Path, Password, Subject, Store, Location, AllowInvalid ?? false); + + public static bool operator ==(CertificateConfig lhs, CertificateConfig rhs) => lhs is null ? rhs is null : lhs.Equals(rhs); + public static bool operator !=(CertificateConfig lhs, CertificateConfig rhs) => !(lhs == rhs); } } diff --git a/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs b/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs index a373c240bf8d..66964340ff19 100644 --- a/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs +++ b/src/Servers/Kestrel/Core/src/Internal/ConnectionDispatcher.cs @@ -10,29 +10,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { - internal class ConnectionDispatcher + internal class ConnectionDispatcher where T : BaseConnectionContext { private static long _lastConnectionId = long.MinValue; private readonly ServiceContext _serviceContext; - private readonly ConnectionDelegate _connectionDelegate; + private readonly Func _connectionDelegate; + private readonly TransportConnectionManager _transportConnectionManager; private readonly TaskCompletionSource _acceptLoopTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public ConnectionDispatcher(ServiceContext serviceContext, ConnectionDelegate connectionDelegate) + public ConnectionDispatcher(ServiceContext serviceContext, Func connectionDelegate, TransportConnectionManager transportConnectionManager) { _serviceContext = serviceContext; _connectionDelegate = connectionDelegate; + _transportConnectionManager = transportConnectionManager; } private IKestrelTrace Log => _serviceContext.Log; - public Task StartAcceptingConnections(IConnectionListener listener) + public Task StartAcceptingConnections(IConnectionListener listener) { ThreadPool.UnsafeQueueUserWorkItem(StartAcceptingConnectionsCore, listener, preferLocal: false); return _acceptLoopTcs.Task; } - private void StartAcceptingConnectionsCore(IConnectionListener listener) + private void StartAcceptingConnectionsCore(IConnectionListener listener) { // REVIEW: Multiple accept loops in parallel? _ = AcceptConnectionsAsync(); @@ -53,9 +55,10 @@ async Task AcceptConnectionsAsync() // Add the connection to the connection manager before we queue it for execution var id = Interlocked.Increment(ref _lastConnectionId); - var kestrelConnection = new KestrelConnection(id, _serviceContext, c => _connectionDelegate(c), connection, Log); + var kestrelConnection = new KestrelConnection( + id, _serviceContext, _transportConnectionManager, _connectionDelegate, connection, Log); - _serviceContext.ConnectionManager.AddConnection(id, kestrelConnection); + _transportConnectionManager.AddConnection(id, kestrelConnection); Log.ConnectionAccepted(connection.ConnectionId); diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionManager.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionManager.cs index 05bb0f0726a5..87e5f08be547 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionManager.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionManager.cs @@ -31,9 +31,9 @@ public ConnectionManager(IKestrelTrace trace, ResourceCounter upgradedConnection /// public ResourceCounter UpgradedConnectionCount { get; } - public void AddConnection(long id, KestrelConnection connection) + public void AddConnection(long id, ConnectionReference connectionReference) { - if (!_connectionReferences.TryAdd(id, new ConnectionReference(connection))) + if (!_connectionReferences.TryAdd(id, connectionReference)) { throw new ArgumentException(nameof(id)); } @@ -67,52 +67,13 @@ public void Walk(Action callback) // It's safe to modify the ConcurrentDictionary in the foreach. // The connection reference has become unrooted because the application never completed. _trace.ApplicationNeverCompleted(reference.ConnectionId); + reference.StopTrasnsportTracking(); } // If both conditions are false, the connection was removed during the heartbeat. } } - public async Task CloseAllConnectionsAsync(CancellationToken token) - { - var closeTasks = new List(); - - Walk(connection => - { - connection.RequestClose(); - closeTasks.Add(connection.ExecutionTask); - }); - - var allClosedTask = Task.WhenAll(closeTasks.ToArray()); - return await Task.WhenAny(allClosedTask, CancellationTokenAsTask(token)).ConfigureAwait(false) == allClosedTask; - } - - public async Task AbortAllConnectionsAsync() - { - var abortTasks = new List(); - - Walk(connection => - { - connection.TransportConnection.Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedDuringServerShutdown)); - abortTasks.Add(connection.ExecutionTask); - }); - - var allAbortedTask = Task.WhenAll(abortTasks.ToArray()); - return await Task.WhenAny(allAbortedTask, Task.Delay(1000)).ConfigureAwait(false) == allAbortedTask; - } - - private static Task CancellationTokenAsTask(CancellationToken token) - { - if (token.IsCancellationRequested) - { - return Task.CompletedTask; - } - - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - token.Register(() => tcs.SetResult(null)); - return tcs.Task; - } - private static ResourceCounter GetCounter(long? number) => number.HasValue ? ResourceCounter.Quota(number.Value) diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionReference.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionReference.cs index dd31fde12fb2..f4d58aa52ef4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionReference.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/ConnectionReference.cs @@ -7,12 +7,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure { internal class ConnectionReference { + private readonly long _id; private readonly WeakReference _weakReference; + private readonly TransportConnectionManager _transportConnectionManager; - public ConnectionReference(KestrelConnection connection) + public ConnectionReference(long id, KestrelConnection connection, TransportConnectionManager transportConnectionManager) { + _id = id; + _weakReference = new WeakReference(connection); ConnectionId = connection.TransportConnection.ConnectionId; + + _transportConnectionManager = transportConnectionManager; } public string ConnectionId { get; } @@ -21,5 +27,10 @@ public bool TryGetConnection(out KestrelConnection connection) { return _weakReference.TryGetTarget(out connection); } + + public void StopTrasnsportTracking() + { + _transportConnectionManager.StopTracking(_id); + } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IConnectionListenerBase.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IConnectionListenerBase.cs new file mode 100644 index 000000000000..262da3da7d16 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IConnectionListenerBase.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + /// + /// Defines an interface that represents a listener bound to a specific . + /// + internal interface IConnectionListenerBase : IAsyncDisposable + { + /// + /// The endpoint that was bound. This may differ from the requested endpoint, such as when the caller requested that any free port be selected. + /// + EndPoint EndPoint { get; } + + /// + /// Stops listening for incoming connections. + /// + /// The token to monitor for cancellation requests. + /// A that represents the unbind operation. + ValueTask UnbindAsync(CancellationToken cancellationToken = default); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IConnectionListenerOfT.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IConnectionListenerOfT.cs new file mode 100644 index 000000000000..40fb0a2f751a --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/IConnectionListenerOfT.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + /// + /// Defines an interface that represents a listener bound to a specific . + /// + internal interface IConnectionListener : IConnectionListenerBase where T : BaseConnectionContext + { + /// + /// Begins an asynchronous operation to accept an incoming connection. + /// + /// The token to monitor for cancellation requests. + /// A that completes when a connection is accepted, yielding the representing the connection. + ValueTask AcceptAsync(CancellationToken cancellationToken = default); + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnection.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnection.cs index 5365ed739790..c81dce45673c 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnection.cs @@ -23,13 +23,16 @@ internal abstract class KestrelConnection : IConnectionHeartbeatFeature, IConnec private readonly TaskCompletionSource _completionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); protected readonly long _id; protected readonly ServiceContext _serviceContext; + protected readonly TransportConnectionManager _transportConnectionManager; public KestrelConnection(long id, ServiceContext serviceContext, + TransportConnectionManager transportConnectionManager, IKestrelTrace logger) { _id = id; _serviceContext = serviceContext; + _transportConnectionManager = transportConnectionManager; Logger = logger; ConnectionClosedRequested = _connectionClosingCts.Token; diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnectionOfT.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnectionOfT.cs index d758bf3f1a80..465440ee00ac 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnectionOfT.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/KestrelConnectionOfT.cs @@ -14,10 +14,11 @@ internal class KestrelConnection : KestrelConnection, IThreadPoolWorkItem whe public KestrelConnection(long id, ServiceContext serviceContext, + TransportConnectionManager transportConnectionManager, Func connectionDelegate, T connectionContext, IKestrelTrace logger) - : base(id, serviceContext, logger) + : base(id, serviceContext, transportConnectionManager, logger) { _connectionDelegate = connectionDelegate; _transportConnection = connectionContext; @@ -66,7 +67,7 @@ internal async Task ExecuteAsync() // is properly torn down. await connectionContext.DisposeAsync(); - _serviceContext.ConnectionManager.RemoveConnection(_id); + _transportConnectionManager.RemoveConnection(_id); } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportConnectionManager.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportConnectionManager.cs new file mode 100644 index 000000000000..0021bc5b0cc4 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportConnectionManager.cs @@ -0,0 +1,103 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class TransportConnectionManager + { + private readonly ConnectionManager _connectionManager; + private readonly ConcurrentDictionary _connectionReferences = new ConcurrentDictionary(); + + public TransportConnectionManager(ConnectionManager connectionManager) + { + _connectionManager = connectionManager; + } + + public void AddConnection(long id, KestrelConnection connection) + { + var connectionReference = new ConnectionReference(id, connection, this); + + if (!_connectionReferences.TryAdd(id, connectionReference)) + { + throw new ArgumentException(nameof(id)); + } + + _connectionManager.AddConnection(id, connectionReference); + } + + public void RemoveConnection(long id) + { + if (!_connectionReferences.TryRemove(id, out _)) + { + throw new ArgumentException(nameof(id)); + } + + _connectionManager.RemoveConnection(id); + } + + // This is only called by the ConnectionManager when the connection reference becomes + // unrooted because the application never completed. + public void StopTracking(long id) + { + if (!_connectionReferences.TryRemove(id, out _)) + { + throw new ArgumentException(nameof(id)); + } + } + + public async Task CloseAllConnectionsAsync(CancellationToken token) + { + var closeTasks = new List(); + + foreach (var kvp in _connectionReferences) + { + if (kvp.Value.TryGetConnection(out var connection)) + { + connection.RequestClose(); + closeTasks.Add(connection.ExecutionTask); + } + } + + var allClosedTask = Task.WhenAll(closeTasks.ToArray()); + return await Task.WhenAny(allClosedTask, CancellationTokenAsTask(token)).ConfigureAwait(false) == allClosedTask; + } + + public async Task AbortAllConnectionsAsync() + { + var abortTasks = new List(); + + foreach (var kvp in _connectionReferences) + { + if (kvp.Value.TryGetConnection(out var connection)) + { + connection.TransportConnection.Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedDuringServerShutdown)); + abortTasks.Add(connection.ExecutionTask); + } + } + + var allAbortedTask = Task.WhenAll(abortTasks.ToArray()); + return await Task.WhenAny(allAbortedTask, Task.Delay(1000)).ConfigureAwait(false) == allAbortedTask; + } + + private static Task CancellationTokenAsTask(CancellationToken token) + { + if (token.IsCancellationRequested) + { + return Task.CompletedTask; + } + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + token.Register(() => tcs.SetResult(null)); + return tcs.Task; + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs new file mode 100644 index 000000000000..9971db7c1e17 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/TransportManager.cs @@ -0,0 +1,195 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class TransportManager + { + private readonly List _transports = new List(); + + private readonly IConnectionListenerFactory? _transportFactory; + private readonly IMultiplexedConnectionListenerFactory? _multiplexedTransportFactory; + private readonly ServiceContext _serviceContext; + + public TransportManager( + IConnectionListenerFactory? transportFactory, + IMultiplexedConnectionListenerFactory? multiplexedTransportFactory, + ServiceContext serviceContext) + { + _transportFactory = transportFactory; + _multiplexedTransportFactory = multiplexedTransportFactory; + _serviceContext = serviceContext; + } + + private ConnectionManager ConnectionManager => _serviceContext.ConnectionManager; + private IKestrelTrace Trace => _serviceContext.Log; + + public async Task BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig) + { + if (_transportFactory is null) + { + throw new InvalidOperationException($"Cannot bind with {nameof(ConnectionDelegate)} no {nameof(IConnectionListenerFactory)} is registered."); + } + + var transport = await _transportFactory.BindAsync(endPoint).ConfigureAwait(false); + StartAcceptLoop(new GenericConnectionListener(transport), c => connectionDelegate(c), endpointConfig); + return transport.EndPoint; + } + + public async Task BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, EndpointConfig? endpointConfig) + { + if (_multiplexedTransportFactory is null) + { + throw new InvalidOperationException($"Cannot bind with {nameof(MultiplexedConnectionDelegate)} no {nameof(IMultiplexedConnectionListenerFactory)} is registered."); + } + + var transport = await _multiplexedTransportFactory.BindAsync(endPoint).ConfigureAwait(false); + StartAcceptLoop(new GenericMultiplexedConnectionListener(transport), c => multiplexedConnectionDelegate(c), endpointConfig); + return transport.EndPoint; + } + + private void StartAcceptLoop(IConnectionListener connectionListener, Func connectionDelegate, EndpointConfig? endpointConfig) where T : BaseConnectionContext + { + var transportConnectionManager = new TransportConnectionManager(_serviceContext.ConnectionManager); + var connectionDispatcher = new ConnectionDispatcher(_serviceContext, connectionDelegate, transportConnectionManager); + var acceptLoopTask = connectionDispatcher.StartAcceptingConnections(connectionListener); + + _transports.Add(new ActiveTransport(connectionListener, acceptLoopTask, transportConnectionManager, endpointConfig)); + } + + public Task StopEndpointsAsync(List endpointsToStop, CancellationToken cancellationToken) + { + var transportsToStop = _transports.Where(t => t.EndpointConfig != null && endpointsToStop.Contains(t.EndpointConfig)).ToList(); + return StopTransportsAsync(transportsToStop, cancellationToken); + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return StopTransportsAsync(new List(_transports), cancellationToken); + } + + private async Task StopTransportsAsync(List transportsToStop, CancellationToken cancellationToken) + { + var tasks = new Task[transportsToStop.Count]; + + for (int i = 0; i < transportsToStop.Count; i++) + { + tasks[i] = transportsToStop[i].UnbindAsync(cancellationToken); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + + async Task StopTransportConnection(ActiveTransport transport) + { + if (!await transport.TransportConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false)) + { + Trace.NotAllConnectionsClosedGracefully(); + + if (!await transport.TransportConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false)) + { + Trace.NotAllConnectionsAborted(); + } + } + } + + for (int i = 0; i < transportsToStop.Count; i++) + { + tasks[i] = StopTransportConnection(transportsToStop[i]); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + + for (int i = 0; i < transportsToStop.Count; i++) + { + tasks[i] = transportsToStop[i].DisposeAsync().AsTask(); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + + foreach (var transport in transportsToStop) + { + _transports.Remove(transport); + } + } + + private class ActiveTransport : IAsyncDisposable + { + public ActiveTransport(IConnectionListenerBase transport, Task acceptLoopTask, TransportConnectionManager transportConnectionManager, EndpointConfig? endpointConfig = null) + { + ConnectionListener = transport; + AcceptLoopTask = acceptLoopTask; + TransportConnectionManager = transportConnectionManager; + EndpointConfig = endpointConfig; + } + + public IConnectionListenerBase ConnectionListener { get; } + public Task AcceptLoopTask { get; } + public TransportConnectionManager TransportConnectionManager { get; } + + public EndpointConfig? EndpointConfig { get; } + + public async Task UnbindAsync(CancellationToken cancellationToken) + { + await ConnectionListener.UnbindAsync(cancellationToken).ConfigureAwait(false); + await AcceptLoopTask.ConfigureAwait(false); + } + + public ValueTask DisposeAsync() + { + return ConnectionListener.DisposeAsync(); + } + } + + private class GenericConnectionListener : IConnectionListener + { + private readonly IConnectionListener _connectionListener; + + public GenericConnectionListener(IConnectionListener connectionListener) + { + _connectionListener = connectionListener; + } + + public EndPoint EndPoint => _connectionListener.EndPoint; + + public ValueTask AcceptAsync(CancellationToken cancellationToken = default) + => _connectionListener.AcceptAsync(cancellationToken); + + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + => _connectionListener.UnbindAsync(); + + public ValueTask DisposeAsync() + => _connectionListener.DisposeAsync(); + } + + private class GenericMultiplexedConnectionListener : IConnectionListener + { + private readonly IMultiplexedConnectionListener _multiplexedConnectionListener; + + public GenericMultiplexedConnectionListener(IMultiplexedConnectionListener multiplexedConnectionListener) + { + _multiplexedConnectionListener = multiplexedConnectionListener; + } + + public EndPoint EndPoint => _multiplexedConnectionListener.EndPoint; + + public ValueTask AcceptAsync(CancellationToken cancellationToken = default) + => _multiplexedConnectionListener.AcceptAsync(features: null, cancellationToken); + + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + => _multiplexedConnectionListener.UnbindAsync(); + + public ValueTask DisposeAsync() + => _multiplexedConnectionListener.DisposeAsync(); + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/MultiplexedConnectionDispatcher.cs b/src/Servers/Kestrel/Core/src/Internal/MultiplexedConnectionDispatcher.cs deleted file mode 100644 index e0fe1edbdc72..000000000000 --- a/src/Servers/Kestrel/Core/src/Internal/MultiplexedConnectionDispatcher.cs +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Microsoft.Extensions.Logging; - -namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal -{ - internal class MultiplexedConnectionDispatcher - { - private static long _lastConnectionId = long.MinValue; - - private readonly ServiceContext _serviceContext; - private readonly MultiplexedConnectionDelegate _connectionDelegate; - private readonly TaskCompletionSource _acceptLoopTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - public MultiplexedConnectionDispatcher(ServiceContext serviceContext, MultiplexedConnectionDelegate connectionDelegate) - { - _serviceContext = serviceContext; - _connectionDelegate = connectionDelegate; - } - - private IKestrelTrace Log => _serviceContext.Log; - - public Task StartAcceptingConnections(IMultiplexedConnectionListener listener) - { - ThreadPool.UnsafeQueueUserWorkItem(StartAcceptingConnectionsCore, listener, preferLocal: false); - return _acceptLoopTcs.Task; - } - - private void StartAcceptingConnectionsCore(IMultiplexedConnectionListener listener) - { - // REVIEW: Multiple accept loops in parallel? - _ = AcceptConnectionsAsync(); - - async Task AcceptConnectionsAsync() - { - try - { - while (true) - { - var connection = await listener.AcceptAsync(); - - if (connection == null) - { - // We're done listening - break; - } - - // Add the connection to the connection manager before we queue it for execution - var id = Interlocked.Increment(ref _lastConnectionId); - var kestrelConnection = new KestrelConnection(id, _serviceContext, c => _connectionDelegate(c), connection, Log); - - _serviceContext.ConnectionManager.AddConnection(id, kestrelConnection); - - Log.ConnectionAccepted(connection.ConnectionId); - - ThreadPool.UnsafeQueueUserWorkItem(kestrelConnection, preferLocal: false); - } - } - catch (Exception ex) - { - // REVIEW: If the accept loop ends should this trigger a server shutdown? It will manifest as a hang - Log.LogCritical(0, ex, "The connection listener failed to accept any new connections."); - } - finally - { - _acceptLoopTcs.TrySetResult(null); - } - } - } - } -} diff --git a/src/Servers/Kestrel/Core/src/Internal/ServerAddressesCollection.cs b/src/Servers/Kestrel/Core/src/Internal/ServerAddressesCollection.cs new file mode 100644 index 000000000000..551938a920e0 --- /dev/null +++ b/src/Servers/Kestrel/Core/src/Internal/ServerAddressesCollection.cs @@ -0,0 +1,173 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +#nullable enable + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.AspNetCore.Hosting.Server.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal +{ + internal class ServerAddressesCollection : ICollection + { + private readonly List _addresses = new List(); + private readonly PublicServerAddressesCollection _publicCollection; + + public ServerAddressesCollection() + { + _publicCollection = new PublicServerAddressesCollection(this); + } + + public ICollection PublicCollection => _publicCollection; + + public bool IsReadOnly => false; + + public int Count + { + get + { + lock (_addresses) + { + return _addresses.Count; + } + } + } + + public void PreventPublicMutation() + { + lock (_addresses) + { + _publicCollection.IsReadOnly = true; + } + } + + public void Add(string item) + { + lock (_addresses) + { + _addresses.Add(item); + } + } + + public bool Remove(string item) + { + lock (_addresses) + { + return _addresses.Remove(item); + } + } + + public void Clear() + { + lock (_addresses) + { + _addresses.Clear(); + } + } + + public bool Contains(string item) + { + lock (_addresses) + { + return _addresses.Contains(item); + } + } + + public void CopyTo(string[] array, int arrayIndex) + { + lock (_addresses) + { + _addresses.CopyTo(array, arrayIndex); + } + } + + public IEnumerator GetEnumerator() + { + lock (_addresses) + { + // Copy inside the lock. + return new List(_addresses).GetEnumerator(); + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + private class PublicServerAddressesCollection : ICollection + { + private readonly ServerAddressesCollection _addressesCollection; + private readonly object _addressesLock; + + public PublicServerAddressesCollection(ServerAddressesCollection addresses) + { + _addressesCollection = addresses; + _addressesLock = addresses._addresses; + } + + public bool IsReadOnly { get; set; } + + public int Count => _addressesCollection.Count; + + public void Add(string item) + { + lock (_addressesLock) + { + ThrowIfReadonly(); + _addressesCollection.Add(item); + } + } + + public bool Remove(string item) + { + lock (_addressesLock) + { + ThrowIfReadonly(); + return _addressesCollection.Remove(item); + } + } + + public void Clear() + { + lock (_addressesLock) + { + ThrowIfReadonly(); + _addressesCollection.Clear(); + } + } + + public bool Contains(string item) + { + return _addressesCollection.Contains(item); + } + + public void CopyTo(string[] array, int arrayIndex) + { + _addressesCollection.CopyTo(array, arrayIndex); + } + + public IEnumerator GetEnumerator() + { + return _addressesCollection.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _addressesCollection.GetEnumerator(); + } + + [StackTraceHidden] + private void ThrowIfReadonly() + { + if (IsReadOnly) + { + throw new InvalidOperationException($"{nameof(IServerAddressesFeature)}.{nameof(IServerAddressesFeature.Addresses)} cannot be modified after the server has started."); + } + } + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs b/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs index f8bcd13cde98..262cf985a4d7 100644 --- a/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs +++ b/src/Servers/Kestrel/Core/src/Internal/ServerAddressesFeature.cs @@ -8,7 +8,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { internal class ServerAddressesFeature : IServerAddressesFeature { - public ICollection Addresses { get; } = new List(); + public ServerAddressesCollection InternalCollection { get; } = new ServerAddressesCollection(); + + ICollection IServerAddressesFeature.Addresses => InternalCollection.PublicCollection; public bool PreferHostingUrls { get; set; } } } diff --git a/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs b/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs index 5e202f3efa29..518f90853569 100644 --- a/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs +++ b/src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs @@ -13,7 +13,6 @@ using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Https; -using Microsoft.AspNetCore.Server.Kestrel.Https.Internal; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -25,21 +24,32 @@ public class KestrelConfigurationLoader { private bool _loaded = false; - internal KestrelConfigurationLoader(KestrelServerOptions options, IConfiguration configuration) + internal KestrelConfigurationLoader(KestrelServerOptions options, IConfiguration configuration, bool reloadOnChange) { Options = options ?? throw new ArgumentNullException(nameof(options)); Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); - ConfigurationReader = new ConfigurationReader(Configuration); + ReloadOnChange = reloadOnChange; } public KestrelServerOptions Options { get; } - public IConfiguration Configuration { get; } - internal ConfigurationReader ConfigurationReader { get; } + public IConfiguration Configuration { get; internal set; } + + /// + /// If , Kestrel will dynamically update endpoint bindings when configuration changes. + /// This will only reload endpoints defined in the "Endpoints" section of your Kestrel configuration. Endpoints defined in code will not be reloaded. + /// + internal bool ReloadOnChange { get; } + + private ConfigurationReader ConfigurationReader { get; set; } + private IDictionary> EndpointConfigurations { get; } = new Dictionary>(0, StringComparer.OrdinalIgnoreCase); + // Actions that will be delayed until Load so that they aren't applied if the configuration loader is replaced. private IList EndpointsToAdd { get; } = new List(); + private CertificateConfig DefaultCertificateConfig { get; set; } + /// /// Specifies a configuration Action to run when an endpoint with the given name is loaded from configuration. /// @@ -222,6 +232,26 @@ public void Load() } _loaded = true; + Reload(); + + foreach (var action in EndpointsToAdd) + { + action(); + } + } + + // Adds endpoints from config to KestrelServerOptions.ConfigurationBackedListenOptions and configures some other options. + // Any endpoints that were removed from the last time endpoints were loaded are returned. + internal (List, List) Reload() + { + var endpointsToStop = Options.ConfigurationBackedListenOptions.ToList(); + var endpointsToStart = new List(); + + Options.ConfigurationBackedListenOptions.Clear(); + DefaultCertificateConfig = null; + + ConfigurationReader = new ConfigurationReader(Configuration); + Options.Latin1RequestHeaders = ConfigurationReader.Latin1RequestHeaders; LoadDefaultCert(ConfigurationReader); @@ -229,12 +259,19 @@ public void Load() foreach (var endpoint in ConfigurationReader.Endpoints) { var listenOptions = AddressBinder.ParseAddress(endpoint.Url, out var https); + Options.ApplyEndpointDefaults(listenOptions); if (endpoint.Protocols.HasValue) { listenOptions.Protocols = endpoint.Protocols.Value; } + else + { + // Ensure endpoint is reloaded if it used the default protocol and the protocol changed. + // listenOptions.Protocols should already be set to this by ApplyEndpointDefaults. + endpoint.Protocols = ConfigurationReader.EndpointDefaults.Protocols; + } // Compare to UseHttps(httpsOptions => { }) var httpsOptions = new HttpsConnectionAdapterOptions(); @@ -247,8 +284,25 @@ public void Load() httpsOptions.ServerCertificate = LoadCertificate(endpoint.Certificate, endpoint.Name) ?? httpsOptions.ServerCertificate; - // Fallback - Options.ApplyDefaultCert(httpsOptions); + if (httpsOptions.ServerCertificate == null && httpsOptions.ServerCertificateSelector == null) + { + // Fallback + Options.ApplyDefaultCert(httpsOptions); + + // Ensure endpoint is reloaded if it used the default certificate and the certificate changed. + endpoint.Certificate = DefaultCertificateConfig; + } + } + + // Now that defaults have been loaded, we can compare to the currently bound endpoints to see if the config changed. + // There's no reason to rerun an EndpointConfigurations callback if nothing changed. + var matchingBoundEndpoints = endpointsToStop.Where(o => o.EndpointConfig == endpoint).ToList(); + + if (matchingBoundEndpoints.Count > 0) + { + endpointsToStop.RemoveAll(o => o.EndpointConfig == endpoint); + Options.ConfigurationBackedListenOptions.AddRange(matchingBoundEndpoints); + continue; } if (EndpointConfigurations.TryGetValue(endpoint.Name, out var configureEndpoint)) @@ -268,13 +322,13 @@ public void Load() listenOptions.UseHttps(httpsOptions); } - Options.ListenOptions.Add(listenOptions); - } + listenOptions.EndpointConfig = endpoint; - foreach (var action in EndpointsToAdd) - { - action(); + endpointsToStart.Add(listenOptions); + Options.ConfigurationBackedListenOptions.Add(listenOptions); } + + return (endpointsToStop, endpointsToStart); } private void LoadDefaultCert(ConfigurationReader configReader) @@ -284,22 +338,24 @@ private void LoadDefaultCert(ConfigurationReader configReader) var defaultCert = LoadCertificate(defaultCertConfig, "Default"); if (defaultCert != null) { + DefaultCertificateConfig = defaultCertConfig; Options.DefaultCertificate = defaultCert; } } else { var logger = Options.ApplicationServices.GetRequiredService>(); - var certificate = FindDeveloperCertificateFile(configReader, logger); + var (certificate, certificateConfig) = FindDeveloperCertificateFile(configReader, logger); if (certificate != null) { logger.LocatedDevelopmentCertificate(certificate); + DefaultCertificateConfig = certificateConfig; Options.DefaultCertificate = certificate; } } } - private X509Certificate2 FindDeveloperCertificateFile(ConfigurationReader configReader, ILogger logger) + private (X509Certificate2, CertificateConfig) FindDeveloperCertificateFile(ConfigurationReader configReader, ILogger logger) { string certificatePath = null; try @@ -311,9 +367,13 @@ private X509Certificate2 FindDeveloperCertificateFile(ConfigurationReader config File.Exists(certificatePath)) { var certificate = new X509Certificate2(certificatePath, certificateConfig.Password); - return IsDevelopmentCertificate(certificate) ? certificate : null; + + if (IsDevelopmentCertificate(certificate)) + { + return (certificate, certificateConfig); + } } - else if (!File.Exists(certificatePath)) + else if (!string.IsNullOrEmpty(certificatePath)) { logger.FailedToLocateDevelopmentCertificateFile(certificatePath); } @@ -323,10 +383,10 @@ private X509Certificate2 FindDeveloperCertificateFile(ConfigurationReader config logger.FailedToLoadDevelopmentCertificate(certificatePath); } - return null; + return (null, null); } - private bool IsDevelopmentCertificate(X509Certificate2 certificate) + private static bool IsDevelopmentCertificate(X509Certificate2 certificate) { if (!string.Equals(certificate.Subject, "CN=localhost", StringComparison.Ordinal)) { diff --git a/src/Servers/Kestrel/Core/src/KestrelServer.cs b/src/Servers/Kestrel/Core/src/KestrelServer.cs index fdd2b47319cf..8359b9b4baea 100644 --- a/src/Servers/Kestrel/Core/src/KestrelServer.cs +++ b/src/Servers/Kestrel/Core/src/KestrelServer.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.IO.Pipelines; using System.Linq; -using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; @@ -17,25 +16,30 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Server.Kestrel.Core { public class KestrelServer : IServer { - private readonly List<(IConnectionListener, Task)> _transports = new List<(IConnectionListener, Task)>(); - private readonly List<(IMultiplexedConnectionListener, Task)> _multiplexedTransports = new List<(IMultiplexedConnectionListener, Task)>(); - private readonly IServerAddressesFeature _serverAddresses; - private readonly List _transportFactories; - private readonly List _multiplexedTransportFactories; + private readonly ServerAddressesFeature _serverAddresses; + private readonly TransportManager _transportManager; + private readonly IConnectionListenerFactory _transportFactory; + private readonly IMultiplexedConnectionListenerFactory _multiplexedTransportFactory; + private readonly SemaphoreSlim _bindSemaphore = new SemaphoreSlim(initialCount: 1); private bool _hasStarted; private int _stopping; + private readonly CancellationTokenSource _stopCts = new CancellationTokenSource(); private readonly TaskCompletionSource _stoppedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private IDisposable _configChangedRegistration; + public KestrelServer(IOptions options, IEnumerable transportFactories, ILoggerFactory loggerFactory) : this(transportFactories, null, CreateServiceContext(options, loggerFactory)) { } + public KestrelServer(IOptions options, IEnumerable transportFactories, IEnumerable multiplexedFactories, ILoggerFactory loggerFactory) : this(transportFactories, multiplexedFactories, CreateServiceContext(options, loggerFactory)) { @@ -55,10 +59,10 @@ internal KestrelServer(IEnumerable transportFactorie throw new ArgumentNullException(nameof(transportFactories)); } - _transportFactories = transportFactories.ToList(); - _multiplexedTransportFactories = multiplexedFactories?.ToList(); + _transportFactory = transportFactories?.LastOrDefault(); + _multiplexedTransportFactory = multiplexedFactories?.LastOrDefault(); - if (_transportFactories.Count == 0 && (_multiplexedTransportFactories == null || _multiplexedTransportFactories.Count == 0)) + if (_transportFactory == null && _multiplexedTransportFactory == null) { throw new InvalidOperationException(CoreStrings.TransportNotFound); } @@ -67,7 +71,9 @@ internal KestrelServer(IEnumerable transportFactorie Features = new FeatureCollection(); _serverAddresses = new ServerAddressesFeature(); - Features.Set(_serverAddresses); + Features.Set(_serverAddresses); + + _transportManager = new TransportManager(_transportFactory, _multiplexedTransportFactory, ServiceContext); HttpCharacters.Initialize(); } @@ -120,7 +126,7 @@ private static ServiceContext CreateServiceContext(IOptions ServiceContext.Log; - private ConnectionManager ConnectionManager => ServiceContext.ConnectionManager; + private AddressBindContext AddressBindContext { get; set; } public async Task StartAsync(IHttpApplication application, CancellationToken cancellationToken) { @@ -148,22 +154,18 @@ async Task OnBind(ListenOptions options) // sockets for it to successfully listen. It also seems racy. if ((options.Protocols & HttpProtocols.Http3) == HttpProtocols.Http3) { - if (_multiplexedTransportFactories == null || _multiplexedTransportFactories.Count == 0) + if (_multiplexedTransportFactory is null) { - throw new InvalidOperationException("Cannot start HTTP/3 server if no MultiplexedTransportFactories are registered."); + throw new InvalidOperationException($"Cannot start HTTP/3 server if no {nameof(IMultiplexedConnectionListenerFactory)} is registered."); } options.UseHttp3Server(ServiceContext, application, options.Protocols); - var multiplxedConnectionDelegate = ((IMultiplexedConnectionBuilder)options).Build(); + var multiplexedConnectionDelegate = ((IMultiplexedConnectionBuilder)options).Build(); - var multiplexedConnectionDispatcher = new MultiplexedConnectionDispatcher(ServiceContext, multiplxedConnectionDelegate); - var multiplexedFactory = _multiplexedTransportFactories.Last(); - var multiplexedTransport = await multiplexedFactory.BindAsync(options.EndPoint).ConfigureAwait(false); - - var acceptLoopTask = multiplexedConnectionDispatcher.StartAcceptingConnections(multiplexedTransport); - _multiplexedTransports.Add((multiplexedTransport, acceptLoopTask)); + // Add the connection limit middleware + multiplexedConnectionDelegate = EnforceConnectionLimit(multiplexedConnectionDelegate, Options.Limits.MaxConcurrentConnections, Trace); - options.EndPoint = multiplexedTransport.EndPoint; + options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options.EndpointConfig).ConfigureAwait(false); } // Add the HTTP middleware as the terminal connection middleware @@ -172,27 +174,30 @@ async Task OnBind(ListenOptions options) || options.Protocols == HttpProtocols.None) // TODO a test fails because it doesn't throw an exception in the right place // when there is no HttpProtocols in KestrelServer, can we remove/change the test? { - options.UseHttpServer(ServiceContext, application, options.Protocols); - var connectionDelegate = options.Build(); - - // Add the connection limit middleware - if (Options.Limits.MaxConcurrentConnections.HasValue) + if (_transportFactory is null) { - connectionDelegate = new ConnectionLimitMiddleware(connectionDelegate, Options.Limits.MaxConcurrentConnections.Value, Trace).OnConnectionAsync; + throw new InvalidOperationException($"Cannot start HTTP/1.x or HTTP/2 server if no {nameof(IConnectionListenerFactory)} is registered."); } - var connectionDispatcher = new ConnectionDispatcher(ServiceContext, connectionDelegate); - var factory = _transportFactories.Last(); - var transport = await factory.BindAsync(options.EndPoint).ConfigureAwait(false); + options.UseHttpServer(ServiceContext, application, options.Protocols); + var connectionDelegate = options.Build(); - var acceptLoopTask = connectionDispatcher.StartAcceptingConnections(transport); + // Add the connection limit middleware + connectionDelegate = EnforceConnectionLimit(connectionDelegate, Options.Limits.MaxConcurrentConnections, Trace); - _transports.Add((transport, acceptLoopTask)); - options.EndPoint = transport.EndPoint; + options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig).ConfigureAwait(false); } } - await AddressBinder.BindAsync(_serverAddresses, Options, Trace, OnBind).ConfigureAwait(false); + AddressBindContext = new AddressBindContext + { + ServerAddressesFeature = _serverAddresses, + ServerOptions = Options, + Logger = Trace, + CreateBinding = OnBind, + }; + + await BindAsync(cancellationToken).ConfigureAwait(false); } catch (Exception ex) { @@ -211,73 +216,139 @@ public async Task StopAsync(CancellationToken cancellationToken) return; } + _stopCts.Cancel(); + + // Don't use cancellationToken when acquiring the semaphore. Dispose calls this with a pre-canceled token. + await _bindSemaphore.WaitAsync().ConfigureAwait(false); + try { - var connectionTransportCount = _transports.Count; - var totalTransportCount = _transports.Count + _multiplexedTransports.Count; - var tasks = new Task[totalTransportCount]; + await _transportManager.StopAsync(cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + _stoppedTcs.TrySetException(ex); + throw; + } + finally + { + ServiceContext.Heartbeat?.Dispose(); + _configChangedRegistration?.Dispose(); + _stopCts.Dispose(); + _bindSemaphore.Release(); + } + + _stoppedTcs.TrySetResult(null); + } + + // Ungraceful shutdown + public void Dispose() + { + StopAsync(new CancellationToken(canceled: true)).GetAwaiter().GetResult(); + } + + private async Task BindAsync(CancellationToken cancellationToken) + { + await _bindSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - for (int i = 0; i < connectionTransportCount; i++) + try + { + if (_stopping == 1) { - (IConnectionListener listener, Task acceptLoop) = _transports[i]; - tasks[i] = Task.WhenAll(listener.UnbindAsync(cancellationToken).AsTask(), acceptLoop); + throw new InvalidOperationException("Kestrel has already been stopped."); } - for (int i = connectionTransportCount; i < totalTransportCount; i++) + IChangeToken reloadToken = null; + + _serverAddresses.InternalCollection.PreventPublicMutation(); + + if (Options.ConfigurationLoader?.ReloadOnChange == true && (!_serverAddresses.PreferHostingUrls || _serverAddresses.InternalCollection.Count == 0)) { - (IMultiplexedConnectionListener listener, Task acceptLoop) = _multiplexedTransports[i - connectionTransportCount]; - tasks[i] = Task.WhenAll(listener.UnbindAsync(cancellationToken).AsTask(), acceptLoop); + reloadToken = Options.ConfigurationLoader.Configuration.GetReloadToken(); } - await Task.WhenAll(tasks).ConfigureAwait(false); + Options.ConfigurationLoader?.Load(); - if (!await ConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false)) - { - Trace.NotAllConnectionsClosedGracefully(); + await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext).ConfigureAwait(false); + _configChangedRegistration = reloadToken?.RegisterChangeCallback(async state => await ((KestrelServer)state).RebindAsync(), this); + } + finally + { + _bindSemaphore.Release(); + } + } - if (!await ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false)) - { - Trace.NotAllConnectionsAborted(); - } - } + private async Task RebindAsync() + { + await _bindSemaphore.WaitAsync(); + + IChangeToken reloadToken = null; - for (int i = 0; i < connectionTransportCount; i++) + try + { + if (_stopping == 1) { - (IConnectionListener listener, Task acceptLoop) = _transports[i]; - tasks[i] = listener.DisposeAsync().AsTask(); + return; } - for (int i = connectionTransportCount; i < totalTransportCount; i++) + reloadToken = Options.ConfigurationLoader.Configuration.GetReloadToken(); + var (endpointsToStop, endpointsToStart) = Options.ConfigurationLoader.Reload(); + + Trace.LogDebug("Config reload token fired. Checking for changes..."); + + if (endpointsToStop.Count > 0) { - (IMultiplexedConnectionListener listener, Task acceptLoop) = _multiplexedTransports[i - connectionTransportCount]; - tasks[i] = listener.DisposeAsync().AsTask(); + var urlsToStop = endpointsToStop.Select(lo => lo.EndpointConfig.Url ?? ""); + Trace.LogInformation("Config changed. Stopping the following endpoints: '{endpoints}'", string.Join("', '", urlsToStop)); + + // 5 is the default value for WebHost's "shutdownTimeoutSeconds", so use that. + using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + using var combinedCts = CancellationTokenSource.CreateLinkedTokenSource(_stopCts.Token, timeoutCts.Token); + + // TODO: It would be nice to start binding to new endpoints immediately and reconfigured endpoints as soon + // as the unbinding finished for the given endpoint rather than wait for all transports to unbind first. + var configsToStop = endpointsToStop.Select(lo => lo.EndpointConfig).ToList(); + await _transportManager.StopEndpointsAsync(configsToStop, combinedCts.Token).ConfigureAwait(false); + + foreach (var listenOption in endpointsToStop) + { + Options.OptionsInUse.Remove(listenOption); + _serverAddresses.InternalCollection.Remove(listenOption.GetDisplayName()); + } } - await Task.WhenAll(tasks).ConfigureAwait(false); + if (endpointsToStart.Count > 0) + { + var urlsToStart = endpointsToStart.Select(lo => lo.EndpointConfig.Url ?? ""); + Trace.LogInformation("Config changed. Starting the following endpoints: '{endpoints}'", string.Join("', '", urlsToStart)); - ServiceContext.Heartbeat?.Dispose(); + foreach (var listenOption in endpointsToStart) + { + try + { + // TODO: This should probably be canceled by the _stopCts too, but we don't currently support bind cancellation even in StartAsync(). + await listenOption.BindAsync(AddressBindContext).ConfigureAwait(false); + } + catch (Exception ex) + { + Trace.LogCritical(0, ex, "Unable to bind to '{url}' on config reload.", listenOption.EndpointConfig.Url ?? ""); + } + } + } } catch (Exception ex) { - _stoppedTcs.TrySetException(ex); - throw; + Trace.LogCritical(0, ex, "Unable to reload configuration."); + } + finally + { + _configChangedRegistration = reloadToken?.RegisterChangeCallback(async state => await ((KestrelServer)state).RebindAsync(), this); + _bindSemaphore.Release(); } - - _stoppedTcs.TrySetResult(null); - } - - // Ungraceful shutdown - public void Dispose() - { - var cancelledTokenSource = new CancellationTokenSource(); - cancelledTokenSource.Cancel(); - StopAsync(cancelledTokenSource.Token).GetAwaiter().GetResult(); } private void ValidateOptions() { - Options.ConfigurationLoader?.Load(); - if (Options.Limits.MaxRequestBufferSize.HasValue && Options.Limits.MaxRequestBufferSize < Options.Limits.MaxRequestLineSize) { @@ -292,5 +363,25 @@ private void ValidateOptions() CoreStrings.FormatMaxRequestBufferSmallerThanRequestHeaderBuffer(Options.Limits.MaxRequestBufferSize.Value, Options.Limits.MaxRequestHeadersTotalSize)); } } + + private static ConnectionDelegate EnforceConnectionLimit(ConnectionDelegate innerDelegate, long? connectionLimit, IKestrelTrace trace) + { + if (!connectionLimit.HasValue) + { + return innerDelegate; + } + + return new ConnectionLimitMiddleware(c => innerDelegate(c), connectionLimit.Value, trace).OnConnectionAsync; + } + + private static MultiplexedConnectionDelegate EnforceConnectionLimit(MultiplexedConnectionDelegate innerDelegate, long? connectionLimit, IKestrelTrace trace) + { + if (!connectionLimit.HasValue) + { + return innerDelegate; + } + + return new ConnectionLimitMiddleware(c => innerDelegate(c), connectionLimit.Value, trace).OnConnectionAsync; + } } } diff --git a/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs b/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs index ce6cc1451b96..e643478c11f8 100644 --- a/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs +++ b/src/Servers/Kestrel/Core/src/KestrelServerOptions.cs @@ -22,13 +22,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core /// public class KestrelServerOptions { - /// - /// Configures the endpoints that Kestrel should listen to. - /// - /// - /// If this list is empty, the server.urls setting (e.g. UseUrls) is used. - /// - internal List ListenOptions { get; } = new List(); + // The following two lists configure the endpoints that Kestrel should listen to. If both lists are empty, the "urls" config setting (e.g. UseUrls) is used. + internal List CodeBackedListenOptions { get; } = new List(); + internal List ConfigurationBackedListenOptions { get; } = new List(); + internal IEnumerable ListenOptions => CodeBackedListenOptions.Concat(ConfigurationBackedListenOptions); + + // For testing and debugging. + internal List OptionsInUse { get; } = new List(); /// /// Gets or sets whether the Server header should be included in each response. @@ -202,20 +202,31 @@ private void EnsureDefaultCert() /// /// Creates a configuration loader for setting up Kestrel. /// - public KestrelConfigurationLoader Configure() - { - var loader = new KestrelConfigurationLoader(this, new ConfigurationBuilder().Build()); - ConfigurationLoader = loader; - return loader; - } + /// A for configuring endpoints. + public KestrelConfigurationLoader Configure() => Configure(new ConfigurationBuilder().Build()); + + /// + /// Creates a configuration loader for setting up Kestrel that takes an as input. + /// This configuration must be scoped to the configuration section for Kestrel. + /// Call to enable dynamic endpoint binding updates. + /// + /// The configuration section for Kestrel. + /// A for further endpoint configuration. + public KestrelConfigurationLoader Configure(IConfiguration config) => Configure(config, reloadOnChange: false); /// - /// Creates a configuration loader for setting up Kestrel that takes an IConfiguration as input. + /// Creates a configuration loader for setting up Kestrel that takes an as input. /// This configuration must be scoped to the configuration section for Kestrel. /// - public KestrelConfigurationLoader Configure(IConfiguration config) + /// The configuration section for Kestrel. + /// + /// If , Kestrel will dynamically update endpoint bindings when configuration changes. + /// This will only reload endpoints defined in the "Endpoints" section of your . Endpoints defined in code will not be reloaded. + /// + /// A for further endpoint configuration. + public KestrelConfigurationLoader Configure(IConfiguration config, bool reloadOnChange) { - var loader = new KestrelConfigurationLoader(this, config); + var loader = new KestrelConfigurationLoader(this, config, reloadOnChange); ConfigurationLoader = loader; return loader; } @@ -286,7 +297,7 @@ public void Listen(EndPoint endPoint, Action configure) var listenOptions = new ListenOptions(endPoint); ApplyEndpointDefaults(listenOptions); configure(listenOptions); - ListenOptions.Add(listenOptions); + CodeBackedListenOptions.Add(listenOptions); } /// @@ -309,7 +320,7 @@ public void ListenLocalhost(int port, Action configure) var listenOptions = new LocalhostListenOptions(port); ApplyEndpointDefaults(listenOptions); configure(listenOptions); - ListenOptions.Add(listenOptions); + CodeBackedListenOptions.Add(listenOptions); } /// @@ -330,7 +341,7 @@ public void ListenAnyIP(int port, Action configure) var listenOptions = new AnyIPListenOptions(port); ApplyEndpointDefaults(listenOptions); configure(listenOptions); - ListenOptions.Add(listenOptions); + CodeBackedListenOptions.Add(listenOptions); } /// @@ -364,7 +375,7 @@ public void ListenUnixSocket(string socketPath, Action configure) var listenOptions = new ListenOptions(socketPath); ApplyEndpointDefaults(listenOptions); configure(listenOptions); - ListenOptions.Add(listenOptions); + CodeBackedListenOptions.Add(listenOptions); } /// @@ -389,7 +400,7 @@ public void ListenHandle(ulong handle, Action configure) var listenOptions = new ListenOptions(handle); ApplyEndpointDefaults(listenOptions); configure(listenOptions); - ListenOptions.Add(listenOptions); + CodeBackedListenOptions.Add(listenOptions); } } } diff --git a/src/Servers/Kestrel/Core/src/ListenOptions.cs b/src/Servers/Kestrel/Core/src/ListenOptions.cs index 21fb635d5939..bc90bb2fbddd 100644 --- a/src/Servers/Kestrel/Core/src/ListenOptions.cs +++ b/src/Servers/Kestrel/Core/src/ListenOptions.cs @@ -17,6 +17,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core /// public class ListenOptions : IConnectionBuilder, IMultiplexedConnectionBuilder { + internal static readonly HttpProtocols DefaultHttpProtocols = HttpProtocols.Http1AndHttp2; + internal readonly List> _middleware = new List>(); internal readonly List> _multiplexedMiddleware = new List>(); @@ -42,6 +44,9 @@ internal ListenOptions(ulong fileHandle, FileHandleType handleType) public EndPoint EndPoint { get; internal set; } + // For comparing bound endpoints to changed config during endpoint config reload. + internal EndpointConfig EndpointConfig { get; set; } + // IPEndPoint is mutable so port 0 can be updated to the bound port. /// /// The to bind to. @@ -71,7 +76,7 @@ internal ListenOptions(ulong fileHandle, FileHandleType handleType) /// The protocols enabled on this endpoint. /// /// Defaults to HTTP/1.x and HTTP/2. - public HttpProtocols Protocols { get; set; } = HttpProtocols.Http1AndHttp2; + public HttpProtocols Protocols { get; set; } = DefaultHttpProtocols; public IServiceProvider ApplicationServices => KestrelServerOptions?.ApplicationServices; @@ -79,16 +84,10 @@ internal string Scheme { get { - if (IsHttp) - { - return IsTls ? "https" : "http"; - } - return "tcp"; + return IsTls ? "https" : "http"; } } - internal bool IsHttp { get; set; } = true; - internal bool IsTls { get; set; } /// diff --git a/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs b/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs index 1c465cb14720..6a0ef4c52d9f 100644 --- a/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs +++ b/src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs @@ -73,7 +73,8 @@ internal ListenOptions Clone(IPAddress address) { KestrelServerOptions = KestrelServerOptions, Protocols = Protocols, - IsTls = IsTls + IsTls = IsTls, + EndpointConfig = EndpointConfig }; options._middleware.AddRange(_middleware); diff --git a/src/Servers/Kestrel/Core/src/Middleware/ConnectionLimitMiddleware.cs b/src/Servers/Kestrel/Core/src/Middleware/ConnectionLimitMiddleware.cs index c999063f1258..650c1d54cb94 100644 --- a/src/Servers/Kestrel/Core/src/Middleware/ConnectionLimitMiddleware.cs +++ b/src/Servers/Kestrel/Core/src/Middleware/ConnectionLimitMiddleware.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -8,33 +9,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal { - internal class ConnectionLimitMiddleware + internal class ConnectionLimitMiddleware where T : BaseConnectionContext { - private readonly ConnectionDelegate _next; + private readonly Func _next; private readonly ResourceCounter _concurrentConnectionCounter; private readonly IKestrelTrace _trace; - public ConnectionLimitMiddleware(ConnectionDelegate next, long connectionLimit, IKestrelTrace trace) + public ConnectionLimitMiddleware(Func next, long connectionLimit, IKestrelTrace trace) : this(next, ResourceCounter.Quota(connectionLimit), trace) { } // For Testing - internal ConnectionLimitMiddleware(ConnectionDelegate next, ResourceCounter concurrentConnectionCounter, IKestrelTrace trace) + internal ConnectionLimitMiddleware(Func next, ResourceCounter concurrentConnectionCounter, IKestrelTrace trace) { _next = next; _concurrentConnectionCounter = concurrentConnectionCounter; _trace = trace; } - public async Task OnConnectionAsync(ConnectionContext connection) + public async Task OnConnectionAsync(T connection) { if (!_concurrentConnectionCounter.TryLockOne()) { KestrelEventSource.Log.ConnectionRejected(connection.ConnectionId); _trace.ConnectionRejected(connection.ConnectionId); - connection.Transport.Input.Complete(); - connection.Transport.Output.Complete(); + await connection.DisposeAsync(); return; } diff --git a/src/Servers/Kestrel/Core/test/AddressBinderTests.cs b/src/Servers/Kestrel/Core/test/AddressBinderTests.cs index d4aaee32b78f..798409fb2d7e 100644 --- a/src/Servers/Kestrel/Core/test/AddressBinderTests.cs +++ b/src/Servers/Kestrel/Core/test/AddressBinderTests.cs @@ -116,14 +116,19 @@ public void ParseAddressIP(string address, string ip, int port, bool isHttps) public async Task WrapsAddressInUseExceptionAsIOException() { var addresses = new ServerAddressesFeature(); - addresses.Addresses.Add("http://localhost:5000"); + addresses.InternalCollection.Add("http://localhost:5000"); var options = new KestrelServerOptions(); + var addressBindContext = new AddressBindContext + { + ServerAddressesFeature = addresses, + ServerOptions = options, + Logger = NullLogger.Instance, + CreateBinding = endpoint => throw new AddressInUseException("already in use"), + }; + await Assert.ThrowsAsync(() => - AddressBinder.BindAsync(addresses, - options, - NullLogger.Instance, - endpoint => throw new AddressInUseException("already in use"))); + AddressBinder.BindAsync(options.ListenOptions, addressBindContext)); } [Theory] @@ -134,16 +139,18 @@ public async Task FallbackToIPv4WhenIPv6AnyBindFails(string address) { var logger = new MockLogger(); var addresses = new ServerAddressesFeature(); - addresses.Addresses.Add(address); + addresses.InternalCollection.Add(address); var options = new KestrelServerOptions(); var ipV6Attempt = false; var ipV4Attempt = false; - await AddressBinder.BindAsync(addresses, - options, - logger, - endpoint => + var addressBindContext = new AddressBindContext + { + ServerAddressesFeature = addresses, + ServerOptions = options, + Logger = logger, + CreateBinding = endpoint => { if (endpoint.IPEndPoint.Address == IPAddress.IPv6Any) { @@ -157,7 +164,10 @@ await AddressBinder.BindAsync(addresses, } return Task.CompletedTask; - }); + }, + }; + + await AddressBinder.BindAsync(options.ListenOptions, addressBindContext); Assert.True(ipV4Attempt, "Should have attempted to bind to IPAddress.Any"); Assert.True(ipV6Attempt, "Should have attempted to bind to IPAddress.IPv6Any"); @@ -188,11 +198,20 @@ public async Task DefaultAddressBinderWithoutDevCertButHttpsConfiguredBindsToHtt }); var endpoints = new List(); - await AddressBinder.BindAsync(addresses, options, logger, listenOptions => + + var addressBindContext = new AddressBindContext { - endpoints.Add(listenOptions); - return Task.CompletedTask; - }); + ServerAddressesFeature = addresses, + ServerOptions = options, + Logger = logger, + CreateBinding = listenOptions => + { + endpoints.Add(listenOptions); + return Task.CompletedTask; + }, + }; + + await AddressBinder.BindAsync(options.ListenOptions, addressBindContext); Assert.Contains(endpoints, e => e.IPEndPoint.Port == 5000 && !e.IsTls); Assert.Contains(endpoints, e => e.IPEndPoint.Port == 5001 && e.IsTls); diff --git a/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs b/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs index 33bba9c48ea1..fc426a210f51 100644 --- a/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs +++ b/src/Servers/Kestrel/Core/test/ConnectionDispatcherTests.cs @@ -29,8 +29,9 @@ public async Task OnConnectionCreatesLogScopeWithConnectionId() var connection = new Mock { CallBase = true }.Object; connection.ConnectionClosed = new CancellationToken(canceled: true); - var kestrelConnection = new KestrelConnection(0, serviceContext, _ => tcs.Task, connection, serviceContext.Log); - serviceContext.ConnectionManager.AddConnection(0, kestrelConnection); + var transportConnectionManager = new TransportConnectionManager(serviceContext.ConnectionManager); + var kestrelConnection = new KestrelConnection(0, serviceContext, transportConnectionManager, _ => tcs.Task, connection, serviceContext.Log); + transportConnectionManager.AddConnection(0, kestrelConnection); var task = kestrelConnection.ExecuteAsync(); @@ -61,7 +62,7 @@ public async Task StartAcceptingConnectionsAsyncLogsIfAcceptAsyncThrows() var logger = ((TestKestrelTrace)serviceContext.Log).Logger; logger.ThrowOnCriticalErrors = false; - var dispatcher = new ConnectionDispatcher(serviceContext, _ => Task.CompletedTask); + var dispatcher = new ConnectionDispatcher(serviceContext, _ => Task.CompletedTask, new TransportConnectionManager(serviceContext.ConnectionManager)); await dispatcher.StartAcceptingConnections(new ThrowingListener()); @@ -79,8 +80,9 @@ public async Task OnConnectionFiresOnCompleted() var connection = new Mock { CallBase = true }.Object; connection.ConnectionClosed = new CancellationToken(canceled: true); - var kestrelConnection = new KestrelConnection(0, serviceContext, _ => Task.CompletedTask, connection, serviceContext.Log); - serviceContext.ConnectionManager.AddConnection(0, kestrelConnection); + var transportConnectionManager = new TransportConnectionManager(serviceContext.ConnectionManager); + var kestrelConnection = new KestrelConnection(0, serviceContext, transportConnectionManager, _ => Task.CompletedTask, connection, serviceContext.Log); + transportConnectionManager.AddConnection(0, kestrelConnection); var completeFeature = kestrelConnection.TransportConnection.Features.Get(); Assert.NotNull(completeFeature); @@ -100,8 +102,9 @@ public async Task OnConnectionOnCompletedExceptionCaught() var logger = ((TestKestrelTrace)serviceContext.Log).Logger; var connection = new Mock { CallBase = true }.Object; connection.ConnectionClosed = new CancellationToken(canceled: true); - var kestrelConnection = new KestrelConnection(0, serviceContext, _ => Task.CompletedTask, connection, serviceContext.Log); - serviceContext.ConnectionManager.AddConnection(0, kestrelConnection); + var transportConnectionManager = new TransportConnectionManager(serviceContext.ConnectionManager); + var kestrelConnection = new KestrelConnection(0, serviceContext, transportConnectionManager, _ => Task.CompletedTask, connection, serviceContext.Log); + transportConnectionManager.AddConnection(0, kestrelConnection); var completeFeature = kestrelConnection.TransportConnection.Features.Get(); Assert.NotNull(completeFeature); @@ -117,7 +120,7 @@ public async Task OnConnectionOnCompletedExceptionCaught() Assert.Equal("An error occurred running an IConnectionCompleteFeature.OnCompleted callback.", errors[0].Message); } - private class ThrowingListener : IConnectionListener + private class ThrowingListener : IConnectionListener { public EndPoint EndPoint { get; set; } diff --git a/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs b/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs index 42867902ff0a..2d3a7a4cbccf 100644 --- a/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpConnectionManagerTests.cs @@ -44,9 +44,9 @@ private void UnrootedConnectionsGetRemovedFromHeartbeatInnerScope( var serviceContext = new TestServiceContext(); var mock = new Mock() { CallBase = true }; mock.Setup(m => m.ConnectionId).Returns(connectionId); - var httpConnection = new KestrelConnection(0, serviceContext, _ => Task.CompletedTask, mock.Object, Mock.Of()); - - httpConnectionManager.AddConnection(0, httpConnection); + var transportConnectionManager = new TransportConnectionManager(httpConnectionManager); + var httpConnection = new KestrelConnection(0, serviceContext, transportConnectionManager, _ => Task.CompletedTask, mock.Object, Mock.Of()); + transportConnectionManager.AddConnection(0, httpConnection); var connectionCount = 0; httpConnectionManager.Walk(_ => connectionCount++); diff --git a/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs b/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs index c8243b6025ad..e9d99c764a28 100644 --- a/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs +++ b/src/Servers/Kestrel/Core/test/KestrelServerOptionsTests.cs @@ -22,7 +22,7 @@ public void ConfigureEndpointDefaultsAppliesToNewEndpoints() var options = new KestrelServerOptions(); options.ListenLocalhost(5000); - Assert.Equal(HttpProtocols.Http1AndHttp2, options.ListenOptions[0].Protocols); + Assert.Equal(HttpProtocols.Http1AndHttp2, options.CodeBackedListenOptions[0].Protocols); options.ConfigureEndpointDefaults(opt => { @@ -34,20 +34,20 @@ public void ConfigureEndpointDefaultsAppliesToNewEndpoints() // ConfigureEndpointDefaults runs before this callback Assert.Equal(HttpProtocols.Http1, opt.Protocols); }); - Assert.Equal(HttpProtocols.Http1, options.ListenOptions[1].Protocols); + Assert.Equal(HttpProtocols.Http1, options.CodeBackedListenOptions[1].Protocols); options.ListenLocalhost(5000, opt => { Assert.Equal(HttpProtocols.Http1, opt.Protocols); opt.Protocols = HttpProtocols.Http2; // Can be overriden }); - Assert.Equal(HttpProtocols.Http2, options.ListenOptions[2].Protocols); + Assert.Equal(HttpProtocols.Http2, options.CodeBackedListenOptions[2].Protocols); options.ListenAnyIP(5000, opt => { opt.Protocols = HttpProtocols.Http2; }); - Assert.Equal(HttpProtocols.Http2, options.ListenOptions[3].Protocols); + Assert.Equal(HttpProtocols.Http2, options.CodeBackedListenOptions[3].Protocols); } } } diff --git a/src/Servers/Kestrel/Core/test/KestrelServerTests.cs b/src/Servers/Kestrel/Core/test/KestrelServerTests.cs index 56454bb7a41a..f7dcef5b3866 100644 --- a/src/Servers/Kestrel/Core/test/KestrelServerTests.cs +++ b/src/Servers/Kestrel/Core/test/KestrelServerTests.cs @@ -14,9 +14,11 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; using Moq; using Xunit; @@ -61,8 +63,8 @@ public void StartWithHttpsAddressConfiguresHttpsEndpoints() StartDummyApplication(server); - Assert.True(server.Options.ListenOptions.Any()); - Assert.True(server.Options.ListenOptions[0].IsTls); + Assert.True(server.Options.OptionsInUse.Any()); + Assert.True(server.Options.OptionsInUse[0].IsTls); } } @@ -248,7 +250,7 @@ public async Task StopAsyncCallsCompleteWhenFirstCallCompletes() { var options = new KestrelServerOptions { - ListenOptions = + CodeBackedListenOptions = { new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) } @@ -305,7 +307,7 @@ public async Task StopAsyncCallsCompleteWithThrownException() { var options = new KestrelServerOptions { - ListenOptions = + CodeBackedListenOptions = { new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) } @@ -365,7 +367,7 @@ public async Task StopAsyncDispatchesSubsequentStopAsyncContinuations() { var options = new KestrelServerOptions { - ListenOptions = + CodeBackedListenOptions = { new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) } @@ -426,7 +428,7 @@ public void StartingServerInitializesHeartbeat() { ServerOptions = { - ListenOptions = + CodeBackedListenOptions = { new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)) } @@ -455,6 +457,220 @@ public void StartingServerInitializesHeartbeat() } } + [Fact] + public async Task ReloadsOnConfigurationChangeWhenOptedIn() + { + var currentConfig = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:A:Url", "http://*:5000"), + new KeyValuePair("Endpoints:B:Url", "http://*:5001"), + }).Build(); + + Func changeCallback = null; + TaskCompletionSource changeCallbackRegisteredTcs = null; + + var mockChangeToken = new Mock(); + mockChangeToken.Setup(t => t.RegisterChangeCallback(It.IsAny>(), It.IsAny())).Returns, object>((callback, state) => + { + changeCallbackRegisteredTcs?.SetResult(null); + + changeCallback = () => + { + changeCallbackRegisteredTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + callback(state); + return changeCallbackRegisteredTcs.Task; + }; + + return Mock.Of(); + }); + + var mockConfig = new Mock(); + mockConfig.Setup(c => c.GetSection(It.IsAny())).Returns(name => currentConfig.GetSection(name)); + mockConfig.Setup(c => c.GetChildren()).Returns(() => currentConfig.GetChildren()); + mockConfig.Setup(c => c.GetReloadToken()).Returns(() => mockChangeToken.Object); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory.Setup(m => m.CreateLogger(It.IsAny())).Returns(Mock.Of()); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(mockLoggerFactory.Object); + serviceCollection.AddSingleton(Mock.Of>()); + + var options = new KestrelServerOptions + { + ApplicationServices = serviceCollection.BuildServiceProvider(), + }; + + options.Configure(mockConfig.Object, reloadOnChange: true); + + var mockTransports = new List>(); + var mockTransportFactory = new Mock(); + mockTransportFactory + .Setup(transportFactory => transportFactory.BindAsync(It.IsAny(), It.IsAny())) + .Returns((e, token) => + { + var mockTransport = new Mock(); + mockTransport + .Setup(transport => transport.AcceptAsync(It.IsAny())) + .Returns(new ValueTask(result: null)); + mockTransport + .Setup(transport => transport.EndPoint) + .Returns(e); + + mockTransports.Add(mockTransport); + + return new ValueTask(mockTransport.Object); + }); + + // Don't use "using". Dispose() could hang if test fails. + var server = new KestrelServer(Options.Create(options), new List() { mockTransportFactory.Object }, mockLoggerFactory.Object); + + await server.StartAsync(new DummyApplication(), CancellationToken.None).DefaultTimeout(); + + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5000), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5001), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5002), It.IsAny()), Times.Never); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5003), It.IsAny()), Times.Never); + + Assert.Equal(2, mockTransports.Count); + + foreach (var mockTransport in mockTransports) + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Never); + } + + currentConfig = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:A:Url", "http://*:5000"), + new KeyValuePair("Endpoints:B:Url", "http://*:5002"), + new KeyValuePair("Endpoints:C:Url", "http://*:5003"), + }).Build(); + + await changeCallback().DefaultTimeout(); + + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5000), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5001), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5002), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5003), It.IsAny()), Times.Once); + + Assert.Equal(4, mockTransports.Count); + + foreach (var mockTransport in mockTransports) + { + if (((IPEndPoint)mockTransport.Object.EndPoint).Port == 5001) + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Once); + } + else + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Never); + } + } + + currentConfig = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:A:Url", "http://*:5000"), + new KeyValuePair("Endpoints:B:Url", "http://*:5002"), + new KeyValuePair("Endpoints:C:Url", "http://*:5003"), + new KeyValuePair("Endpoints:C:Protocols", "Http1"), + }).Build(); + + await changeCallback().DefaultTimeout(); + + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5000), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5001), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5002), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5003), It.IsAny()), Times.Exactly(2)); + + Assert.Equal(5, mockTransports.Count); + + var firstPort5003TransportChecked = false; + + foreach (var mockTransport in mockTransports) + { + var port = ((IPEndPoint)mockTransport.Object.EndPoint).Port; + if (port == 5001) + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Once); + } + else if (port == 5003 && !firstPort5003TransportChecked) + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Once); + firstPort5003TransportChecked = true; + } + else + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Never); + } + } + + await server.StopAsync(CancellationToken.None).DefaultTimeout(); + + foreach (var mockTransport in mockTransports) + { + mockTransport.Verify(t => t.UnbindAsync(It.IsAny()), Times.Once); + } + } + + [Fact] + public async Task DoesNotReloadOnConfigurationChangeByDefault() + { + var currentConfig = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:A:Url", "http://*:5000"), + new KeyValuePair("Endpoints:B:Url", "http://*:5001"), + }).Build(); + + var mockConfig = new Mock(); + mockConfig.Setup(c => c.GetSection(It.IsAny())).Returns(name => currentConfig.GetSection(name)); + mockConfig.Setup(c => c.GetChildren()).Returns(() => currentConfig.GetChildren()); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory.Setup(m => m.CreateLogger(It.IsAny())).Returns(Mock.Of()); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(mockLoggerFactory.Object); + serviceCollection.AddSingleton(Mock.Of>()); + + var options = new KestrelServerOptions + { + ApplicationServices = serviceCollection.BuildServiceProvider(), + }; + + options.Configure(mockConfig.Object); + + var mockTransports = new List>(); + var mockTransportFactory = new Mock(); + mockTransportFactory + .Setup(transportFactory => transportFactory.BindAsync(It.IsAny(), It.IsAny())) + .Returns((e, token) => + { + var mockTransport = new Mock(); + mockTransport + .Setup(transport => transport.AcceptAsync(It.IsAny())) + .Returns(new ValueTask(result: null)); + mockTransport + .Setup(transport => transport.EndPoint) + .Returns(e); + + mockTransports.Add(mockTransport); + + return new ValueTask(mockTransport.Object); + }); + + // Don't use "using". Dispose() could hang if test fails. + var server = new KestrelServer(Options.Create(options), new List() { mockTransportFactory.Object }, mockLoggerFactory.Object); + + await server.StartAsync(new DummyApplication(), CancellationToken.None).DefaultTimeout(); + + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5000), It.IsAny()), Times.Once); + mockTransportFactory.Verify(f => f.BindAsync(new IPEndPoint(IPAddress.IPv6Any, 5001), It.IsAny()), Times.Once); + + mockConfig.Verify(c => c.GetReloadToken(), Times.Never); + + await server.StopAsync(CancellationToken.None).DefaultTimeout(); + } + private static KestrelServer CreateServer(KestrelServerOptions options, ILogger testLogger) { return new KestrelServer(Options.Create(options), new List() { new MockTransportFactory() }, new LoggerFactory(new[] { new KestrelTestLoggerProvider(testLogger) })); diff --git a/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs b/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs index a8b36b29f3bd..22a24014cef5 100644 --- a/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs +++ b/src/Servers/Kestrel/Kestrel/test/ConfigurationReaderTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -97,8 +97,7 @@ public void ReadEndpointWithMissingUrl_Throws() { new KeyValuePair("Endpoints:End1", ""), }).Build(); - var reader = new ConfigurationReader(config); - Assert.Throws(() => reader.Endpoints); + Assert.Throws(() => new ConfigurationReader(config)); } [Fact] @@ -108,8 +107,7 @@ public void ReadEndpointWithEmptyUrl_Throws() { new KeyValuePair("Endpoints:End1:Url", ""), }).Build(); - var reader = new ConfigurationReader(config); - Assert.Throws(() => reader.Endpoints); + Assert.Throws(() => new ConfigurationReader(config)); } [Fact] diff --git a/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs b/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs index a8035946ba3d..2b38083cfd1b 100644 --- a/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs +++ b/src/Servers/Kestrel/Kestrel/test/KestrelConfigurationBuilderTests.cs @@ -46,7 +46,7 @@ public void ConfigureNamedEndpoint_OnlyRunForMatchingConfig() .Load(); Assert.Single(serverOptions.ListenOptions); - Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Equal(5001, serverOptions.ConfigurationBackedListenOptions[0].IPEndPoint.Port); Assert.True(found); } @@ -64,7 +64,7 @@ public void ConfigureEndpoint_OnlyRunWhenBuildIsCalled() serverOptions.ConfigurationLoader.Load(); Assert.Single(serverOptions.ListenOptions); - Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Equal(5001, serverOptions.CodeBackedListenOptions[0].IPEndPoint.Port); Assert.True(run); } @@ -82,13 +82,13 @@ public void CallBuildTwice_OnlyRunsOnce() builder.Load(); Assert.Single(serverOptions.ListenOptions); - Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Equal(5001, serverOptions.CodeBackedListenOptions[0].IPEndPoint.Port); Assert.NotNull(serverOptions.ConfigurationLoader); builder.Load(); Assert.Single(serverOptions.ListenOptions); - Assert.Equal(5001, serverOptions.ListenOptions[0].IPEndPoint.Port); + Assert.Equal(5001, serverOptions.CodeBackedListenOptions[0].IPEndPoint.Port); Assert.NotNull(serverOptions.ConfigurationLoader); } @@ -117,9 +117,9 @@ public void Configure_IsReplaceable() serverOptions.ConfigurationLoader.Load(); - Assert.Equal(2, serverOptions.ListenOptions.Count); - Assert.Equal(5002, serverOptions.ListenOptions[0].IPEndPoint.Port); - Assert.Equal(5003, serverOptions.ListenOptions[1].IPEndPoint.Port); + Assert.Equal(2, serverOptions.ListenOptions.Count()); + Assert.Equal(5002, serverOptions.ConfigurationBackedListenOptions[0].IPEndPoint.Port); + Assert.Equal(5003, serverOptions.CodeBackedListenOptions[0].IPEndPoint.Port); Assert.False(run1); Assert.True(run2); @@ -166,8 +166,8 @@ public void ConfigureDefaultsAppliesToNewConfigureEndpoints() Assert.True(ran1); Assert.True(ran2); - Assert.True(serverOptions.ListenOptions[0].IsTls); - Assert.False(serverOptions.ListenOptions[1].IsTls); + Assert.True(serverOptions.ConfigurationBackedListenOptions[0].IsTls); + Assert.False(serverOptions.CodeBackedListenOptions[0].IsTls); } [Fact] @@ -208,8 +208,8 @@ public void ConfigureEndpointDefaultCanEnableHttps() Assert.True(ran2); // You only get Https once per endpoint. - Assert.True(serverOptions.ListenOptions[0].IsTls); - Assert.True(serverOptions.ListenOptions[1].IsTls); + Assert.True(serverOptions.ConfigurationBackedListenOptions[0].IsTls); + Assert.True(serverOptions.CodeBackedListenOptions[0].IsTls); } [Fact] @@ -477,6 +477,112 @@ public void Latin1RequestHeadersReadFromConfig() Assert.True(options.Latin1RequestHeaders); } + [Fact] + public void Reload_IdentifiesEndpointsToStartAndStop() + { + var serverOptions = CreateServerOptions(); + + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:A:Url", "http://*:5000"), + new KeyValuePair("Endpoints:B:Url", "http://*:5001"), + }).Build(); + + serverOptions.Configure(config).Load(); + + Assert.Equal(2, serverOptions.ConfigurationBackedListenOptions.Count); + Assert.Equal(5000, serverOptions.ConfigurationBackedListenOptions[0].IPEndPoint.Port); + Assert.Equal(5001, serverOptions.ConfigurationBackedListenOptions[1].IPEndPoint.Port); + + serverOptions.ConfigurationLoader.Configuration = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:A:Url", "http://*:5000"), + new KeyValuePair("Endpoints:B:Url", "http://*:5002"), + new KeyValuePair("Endpoints:C:Url", "http://*:5003"), + }).Build(); + + var (endpointsToStop, endpointsToStart) = serverOptions.ConfigurationLoader.Reload(); + + Assert.Single(endpointsToStop); + Assert.Equal(5001, endpointsToStop[0].IPEndPoint.Port); + + Assert.Equal(2, endpointsToStart.Count); + Assert.Equal(5002, endpointsToStart[0].IPEndPoint.Port); + Assert.Equal(5003, endpointsToStart[1].IPEndPoint.Port); + + Assert.Equal(3, serverOptions.ConfigurationBackedListenOptions.Count); + Assert.Equal(5000, serverOptions.ConfigurationBackedListenOptions[0].IPEndPoint.Port); + Assert.Same(endpointsToStart[0], serverOptions.ConfigurationBackedListenOptions[1]); + Assert.Same(endpointsToStart[1], serverOptions.ConfigurationBackedListenOptions[2]); + } + + [Fact] + public void Reload_IdentifiesEndpointsWithChangedDefaults() + { + var serverOptions = CreateServerOptions(); + + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:DefaultProtocol:Url", "http://*:5000"), + new KeyValuePair("Endpoints:NonDefaultProtocol:Url", "http://*:5001"), + new KeyValuePair("Endpoints:NonDefaultProtocol:Protocols", "Http1AndHttp2"), + }).Build(); + + serverOptions.Configure(config).Load(); + + serverOptions.ConfigurationLoader.Configuration = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:DefaultProtocol:Url", "http://*:5000"), + new KeyValuePair("Endpoints:NonDefaultProtocol:Url", "http://*:5001"), + new KeyValuePair("Endpoints:NonDefaultProtocol:Protocols", "Http1AndHttp2"), + new KeyValuePair("EndpointDefaults:Protocols", "Http1"), + }).Build(); + + var (endpointsToStop, endpointsToStart) = serverOptions.ConfigurationLoader.Reload(); + + Assert.Single(endpointsToStop); + Assert.Single(endpointsToStart); + + Assert.Equal(5000, endpointsToStop[0].IPEndPoint.Port); + Assert.Equal(HttpProtocols.Http1AndHttp2, endpointsToStop[0].Protocols); + Assert.Equal(5000, endpointsToStart[0].IPEndPoint.Port); + Assert.Equal(HttpProtocols.Http1, endpointsToStart[0].Protocols); + } + + [Fact] + public void Reload_RerunsNamedEndpointConfigurationOnChange() + { + var foundChangedCount = 0; + var foundUnchangedCount = 0; + var serverOptions = CreateServerOptions(); + + var config = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:Changed:Url", "http://*:5001"), + new KeyValuePair("Endpoints:Unchanged:Url", "http://*:5000"), + }).Build(); + + serverOptions.Configure(config) + .Endpoint("Changed", endpointOptions => foundChangedCount++) + .Endpoint("Unchanged", endpointOptions => foundUnchangedCount++) + .Endpoint("NotFound", endpointOptions => throw new NotImplementedException()) + .Load(); + + Assert.Equal(1, foundChangedCount); + Assert.Equal(1, foundUnchangedCount); + + serverOptions.ConfigurationLoader.Configuration = new ConfigurationBuilder().AddInMemoryCollection(new[] + { + new KeyValuePair("Endpoints:Changed:Url", "http://*:5002"), + new KeyValuePair("Endpoints:Unchanged:Url", "http://*:5000"), + }).Build(); + + serverOptions.ConfigurationLoader.Reload(); + + Assert.Equal(2, foundChangedCount); + Assert.Equal(1, foundUnchangedCount); + } + private static string GetCertificatePath() { var appData = Environment.GetEnvironmentVariable("APPDATA"); diff --git a/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs index db7ad6033aa8..12f802864d37 100644 --- a/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs +++ b/src/Servers/Kestrel/Transport.Libuv/test/LibuvTransportTests.cs @@ -9,10 +9,13 @@ using System.Net.Http; using System.Net.Sockets; using System.Text; +using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal; using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests.TestHelpers; using Microsoft.AspNetCore.Testing; @@ -195,8 +198,9 @@ public async Task OneToTenThreads(int threadCount) await transport.BindAsync(); listenOptions.EndPoint = transport.EndPoint; - var dispatcher = new ConnectionDispatcher(serviceContext, listenOptions.Build()); - var acceptTask = dispatcher.StartAcceptingConnections(transport); + var transportConnectionManager = new TransportConnectionManager(serviceContext.ConnectionManager); + var dispatcher = new ConnectionDispatcher(serviceContext, c => listenOptions.Build()(c), transportConnectionManager); + var acceptTask = dispatcher.StartAcceptingConnections(new GenericConnectionListener(transport)); using (var client = new HttpClient()) { @@ -218,10 +222,31 @@ public async Task OneToTenThreads(int threadCount) await acceptTask; - if (!await serviceContext.ConnectionManager.CloseAllConnectionsAsync(default)) + if (!await transportConnectionManager.CloseAllConnectionsAsync(default)) { - await serviceContext.ConnectionManager.AbortAllConnectionsAsync(); + await transportConnectionManager.AbortAllConnectionsAsync(); } } + + private class GenericConnectionListener : IConnectionListener + { + private readonly IConnectionListener _connectionListener; + + public GenericConnectionListener(IConnectionListener connectionListener) + { + _connectionListener = connectionListener; + } + + public EndPoint EndPoint => _connectionListener.EndPoint; + + public ValueTask AcceptAsync(CancellationToken cancellationToken = default) + => _connectionListener.AcceptAsync(cancellationToken); + + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + => _connectionListener.UnbindAsync(); + + public ValueTask DisposeAsync() + => _connectionListener.DisposeAsync(); + } } } diff --git a/src/Servers/Kestrel/samples/SampleApp/Startup.cs b/src/Servers/Kestrel/samples/SampleApp/Startup.cs index 1e90bc1a52ad..3b61d0fe91dc 100644 --- a/src/Servers/Kestrel/samples/SampleApp/Startup.cs +++ b/src/Servers/Kestrel/samples/SampleApp/Startup.cs @@ -12,7 +12,6 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; @@ -72,8 +71,8 @@ public static Task Main(string[] args) .ConfigureAppConfiguration((hostingContext, config) => { var env = hostingContext.HostingEnvironment; - config.AddJsonFile("appsettings.json", optional: true) - .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true); + config.AddJsonFile("appsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile($"appsettings.{env.EnvironmentName}.json", optional: true, reloadOnChange: true); }) .UseKestrel((context, options) => { @@ -135,8 +134,9 @@ public static Task Main(string[] args) .LocalhostEndpoint(basePort + 7) .Load(); + // reloadOnChange: true is the default options - .Configure(context.Configuration.GetSection("Kestrel")) + .Configure(context.Configuration.GetSection("Kestrel"), reloadOnChange: true) .Endpoint("NamedEndpoint", opt => { diff --git a/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json b/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json index 741bd03aeefb..754ce4b71410 100644 --- a/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json +++ b/src/Servers/Kestrel/samples/SampleApp/appsettings.Development.json @@ -1,4 +1,4 @@ -{ +{ "Kestrel": { "Endpoints": { "NamedEndpoint": { "Url": "http://localhost:6000" }, diff --git a/src/Servers/Kestrel/shared/test/TransportTestHelpers/TestServer.cs b/src/Servers/Kestrel/shared/test/TransportTestHelpers/TestServer.cs index 7ff959bb1f4a..acfb02cd6dd7 100644 --- a/src/Servers/Kestrel/shared/test/TransportTestHelpers/TestServer.cs +++ b/src/Servers/Kestrel/shared/test/TransportTestHelpers/TestServer.cs @@ -43,7 +43,7 @@ public TestServer(RequestDelegate app, TestServiceContext context) } public TestServer(RequestDelegate app, TestServiceContext context, ListenOptions listenOptions) - : this(app, context, options => options.ListenOptions.Add(listenOptions), _ => { }) + : this(app, context, options => options.CodeBackedListenOptions.Add(listenOptions), _ => { }) { } @@ -55,7 +55,7 @@ public TestServer(RequestDelegate app, TestServiceContext context, Action { }) { } diff --git a/src/Servers/Kestrel/test/BindTests/AddressRegistrationTests.cs b/src/Servers/Kestrel/test/BindTests/AddressRegistrationTests.cs index 787d3fe7c8f0..9d6c376e3247 100644 --- a/src/Servers/Kestrel/test/BindTests/AddressRegistrationTests.cs +++ b/src/Servers/Kestrel/test/BindTests/AddressRegistrationTests.cs @@ -828,8 +828,8 @@ public async Task EndpointDefaultsConfig_CanSetProtocolForUrlsConfig(string inpu using (var host = hostBuilder.Build()) { await host.StartAsync(); - Assert.Single(capturedOptions.ListenOptions); - Assert.Equal(expected, capturedOptions.ListenOptions[0].Protocols); + Assert.Single(capturedOptions.OptionsInUse); + Assert.Equal(expected, capturedOptions.OptionsInUse[0].Protocols); await host.StopAsync(); } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ConnectionLimitTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ConnectionLimitTests.cs index 532629f0ddb9..e13b59f92123 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ConnectionLimitTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ConnectionLimitTests.cs @@ -5,6 +5,7 @@ using System.Net; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; @@ -13,7 +14,6 @@ using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; using Microsoft.AspNetCore.Server.Kestrel.Tests; using Microsoft.AspNetCore.Testing; -using Microsoft.Extensions.Logging.Testing; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests @@ -211,7 +211,7 @@ private TestServer CreateServerWithMaxConnections(RequestDelegate app, ResourceC var listenOptions = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)); listenOptions.Use(next => { - var middleware = new ConnectionLimitMiddleware(next, concurrentConnectionCounter, serviceContext.Log); + var middleware = new ConnectionLimitMiddleware(c => next(c), concurrentConnectionCounter, serviceContext.Log); return middleware.OnConnectionAsync; }); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs index 4dd609f85bc2..c95839996a11 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs @@ -4063,7 +4063,7 @@ private static async Task ResponseStatusCodeSetBeforeHttpContextDispose( }); await using (var server = new TestServer(handler, new TestServiceContext(loggerFactory), - options => options.ListenOptions.Add(new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))), + options => options.CodeBackedListenOptions.Add(new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))), services => services.AddSingleton(mockHttpContextFactory.Object))) { using (var connection = server.CreateConnection()) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/TestServer.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/TestServer.cs index 0c8e40921da6..27c0395ed512 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/TestServer.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/TestServer.cs @@ -42,7 +42,7 @@ public TestServer(RequestDelegate app, TestServiceContext context) } public TestServer(RequestDelegate app, TestServiceContext context, ListenOptions listenOptions) - : this(app, context, options => options.ListenOptions.Add(listenOptions), _ => { }) + : this(app, context, options => options.CodeBackedListenOptions.Add(listenOptions), _ => { }) { } @@ -55,7 +55,7 @@ public TestServer(RequestDelegate app, TestServiceContext context, Action { }) {