diff --git a/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt b/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt index 5708e0985dfd..6e20ff457d44 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt +++ b/src/Servers/Kestrel/Transport.Sockets/src/PublicAPI.Unshipped.txt @@ -9,4 +9,7 @@ static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(th static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder, System.Action! configureOptions) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder! static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateDefaultBoundListenSocket(System.Net.EndPoint! endpoint) -> System.Net.Sockets.Socket! Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.get -> System.Func! -Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void \ No newline at end of file +Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void +static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.DefaultAcceptSocketAsync(System.Net.Sockets.Socket! listenSocket, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask +Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.AcceptSocketAsync.get -> System.Func>! +Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.AcceptSocketAsync.set -> void diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs index 485dc9c99746..d875bcc6961b 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketConnectionListener.cs @@ -117,7 +117,7 @@ internal void Bind() { Debug.Assert(_listenSocket != null, "Bind must be called first."); - var acceptSocket = await _listenSocket.AcceptAsync(cancellationToken); + var acceptSocket = await _options.AcceptSocketAsync(_listenSocket, cancellationToken); // Only apply no delay to Tcp based endpoints if (acceptSocket.LocalEndPoint is IPEndPoint) diff --git a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs index 6e2cb7ca4735..ada150e80d5b 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/SocketTransportOptions.cs @@ -80,6 +80,25 @@ public class SocketTransportOptions /// public Func CreateBoundListenSocket { get; set; } = CreateDefaultBoundListenSocket; + /// + /// A function used to accept a new given a listening . + /// + /// + /// The listening passed is the one created by a previous call to . + /// + /// This property defaults to . + /// + public Func> AcceptSocketAsync { get; set; } = DefaultAcceptSocketAsync; + + /// + /// Accepts a new from a listen previously obtained from . + /// + /// A listening . + /// Indicates if the accept operation should be aborted. + /// A newly accepted . + public static ValueTask DefaultAcceptSocketAsync(Socket listenSocket, CancellationToken cancellationToken) + => listenSocket.AcceptAsync(cancellationToken); + /// /// Creates a default instance of for the given /// that can be used by a connection listener to listen for inbound requests. diff --git a/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs b/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs index c8ccbf00bb13..35f9a4e18bc4 100644 --- a/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs +++ b/src/Servers/Kestrel/test/Sockets.BindTests/SocketTransportOptionsTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Net; +using System.Net.Http; using System.Net.Sockets; using System.Runtime.InteropServices; using System.Threading.Tasks; @@ -80,6 +81,37 @@ public void CreateDefaultBoundListenSocket_PreservesLocalEndpointFromFileHandleE Assert.Equal(fileHandleSocket.LocalEndPoint, listenSocket.LocalEndPoint); } + [Fact] + public async Task VerifySocketTransportCallsAcceptSocketAsync() + { + var wasCalled = false; + + ValueTask AcceptSocketAsync(Socket socket, CancellationToken cancellationToken) + { + wasCalled = true; + return socket.AcceptAsync(cancellationToken); + } + + using var host = CreateWebHost( + new IPEndPoint(IPAddress.Loopback, 0), + options => + { + options.AcceptSocketAsync = AcceptSocketAsync; + } + ); + + await host.StartAsync(); + using var client = new HttpClient(); + + var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/"); + response.EnsureSuccessStatusCode(); + + await host.StopAsync(); + + Assert.True(wasCalled, $"Expected {nameof(SocketTransportOptions.AcceptSocketAsync)} to be called."); + await host.StopAsync(); + } + public static IEnumerable GetEndpoints() { // IPv4