diff --git a/src/Http/Routing/src/Builder/RoutingEndpointConventionBuilderExtensions.cs b/src/Http/Routing/src/Builder/RoutingEndpointConventionBuilderExtensions.cs new file mode 100644 index 000000000000..5aaf7498e718 --- /dev/null +++ b/src/Http/Routing/src/Builder/RoutingEndpointConventionBuilderExtensions.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Routing; + +namespace Microsoft.AspNetCore.Builder +{ + /// + /// Extension methods for adding routing metadata to endpoint instances using . + /// + public static class RoutingEndpointConventionBuilderExtensions + { + /// + /// Requires that endpoints match one of the specified hosts during routing. + /// + /// The to add the metadata to. + /// + /// The hosts used during routing. + /// Hosts should be Unicode rather than punycode, and may have a port. + /// An empty collection means any host will be accepted. + /// + /// A reference to this instance after the operation has completed. + public static IEndpointConventionBuilder RequireHost(this IEndpointConventionBuilder builder, params string[] hosts) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (hosts == null) + { + throw new ArgumentNullException(nameof(hosts)); + } + + builder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(new HostAttribute(hosts)); + }); + return builder; + } + } +} diff --git a/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs b/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs index c78c64a5f55f..e18850d0dfc1 100644 --- a/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs +++ b/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs @@ -87,6 +87,7 @@ public static IServiceCollection AddRouting(this IServiceCollection services) // services.TryAddSingleton(); services.TryAddEnumerable(ServiceDescriptor.Singleton()); + services.TryAddEnumerable(ServiceDescriptor.Singleton()); // // Misc infrastructure diff --git a/src/Http/Routing/src/HostAttribute.cs b/src/Http/Routing/src/HostAttribute.cs new file mode 100644 index 000000000000..a26163210ac3 --- /dev/null +++ b/src/Http/Routing/src/HostAttribute.cs @@ -0,0 +1,67 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Microsoft.AspNetCore.Routing +{ + /// + /// Attribute for providing host metdata that is used during routing. + /// + [DebuggerDisplay("{DebuggerToString(),nq}")] + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = false)] + public sealed class HostAttribute : Attribute, IHostMetadata + { + /// + /// Initializes a new instance of the class. + /// + /// + /// The host used during routing. + /// Host should be Unicode rather than punycode, and may have a port. + /// + public HostAttribute(string host) : this(new[] { host }) + { + if (host == null) + { + throw new ArgumentNullException(nameof(host)); + } + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// The hosts used during routing. + /// Hosts should be Unicode rather than punycode, and may have a port. + /// An empty collection means any host will be accepted. + /// + public HostAttribute(params string[] hosts) + { + if (hosts == null) + { + throw new ArgumentNullException(nameof(hosts)); + } + + Hosts = hosts.ToArray(); + } + + /// + /// Returns a read-only collection of hosts used during routing. + /// Hosts will be Unicode rather than punycode, and may have a port. + /// An empty collection means any host will be accepted. + /// + public IReadOnlyList Hosts { get; } + + private string DebuggerToString() + { + var hostsDisplay = (Hosts.Count == 0) + ? "*:*" + : string.Join(",", Hosts.Select(h => h.Contains(':') ? h : h + ":*")); + + return $"Hosts: {hostsDisplay}"; + } + } +} diff --git a/src/Http/Routing/src/HttpMethodMetadata.cs b/src/Http/Routing/src/HttpMethodMetadata.cs index f01bf2a2a5b5..bf2d0eb950ad 100644 --- a/src/Http/Routing/src/HttpMethodMetadata.cs +++ b/src/Http/Routing/src/HttpMethodMetadata.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; diff --git a/src/Http/Routing/src/IHostMetadata.cs b/src/Http/Routing/src/IHostMetadata.cs new file mode 100644 index 000000000000..a3e52aa96c12 --- /dev/null +++ b/src/Http/Routing/src/IHostMetadata.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Routing +{ + /// + /// Represents host metadata used during routing. + /// + public interface IHostMetadata + { + /// + /// Returns a read-only collection of hosts used during routing. + /// Hosts will be Unicode rather than punycode, and may have a port. + /// An empty collection means any host will be accepted. + /// + IReadOnlyList Hosts { get; } + } +} diff --git a/src/Http/Routing/src/Matching/HostMatcherPolicy.cs b/src/Http/Routing/src/Matching/HostMatcherPolicy.cs new file mode 100644 index 000000000000..7f00a8ead5fb --- /dev/null +++ b/src/Http/Routing/src/Matching/HostMatcherPolicy.cs @@ -0,0 +1,366 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Routing.Matching +{ + /// + /// An that implements filtering and selection by + /// the host header of a request. + /// + public sealed class HostMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy + { + // Run after HTTP methods, but before 'default'. + public override int Order { get; } = -100; + + public IComparer Comparer { get; } = new HostMetadataEndpointComparer(); + + public bool AppliesToEndpoints(IReadOnlyList endpoints) + { + if (endpoints == null) + { + throw new ArgumentNullException(nameof(endpoints)); + } + + return endpoints.Any(e => + { + var hosts = e.Metadata.GetMetadata()?.Hosts; + if (hosts == null || hosts.Count == 0) + { + return false; + } + + foreach (var host in hosts) + { + // Don't run policy on endpoints that match everything + var key = CreateEdgeKey(host); + if (!key.MatchesAll) + { + return true; + } + } + + return false; + }); + } + + private static EdgeKey CreateEdgeKey(string host) + { + if (host == null) + { + return EdgeKey.WildcardEdgeKey; + } + + var hostParts = host.Split(':'); + if (hostParts.Length == 1) + { + if (!string.IsNullOrEmpty(hostParts[0])) + { + return new EdgeKey(hostParts[0], null); + } + } + if (hostParts.Length == 2) + { + if (!string.IsNullOrEmpty(hostParts[0])) + { + if (int.TryParse(hostParts[1], out var port)) + { + return new EdgeKey(hostParts[0], port); + } + else if (string.Equals(hostParts[1], "*", StringComparison.Ordinal)) + { + return new EdgeKey(hostParts[0], null); + } + } + } + + throw new InvalidOperationException($"Could not parse host: {host}"); + } + + public IReadOnlyList GetEdges(IReadOnlyList endpoints) + { + if (endpoints == null) + { + throw new ArgumentNullException(nameof(endpoints)); + } + + // The algorithm here is designed to be preserve the order of the endpoints + // while also being relatively simple. Preserving order is important. + + // First, build a dictionary of all of the hosts that are included + // at this node. + // + // For now we're just building up the set of keys. We don't add any endpoints + // to lists now because we don't want ordering problems. + var edges = new Dictionary>(); + for (var i = 0; i < endpoints.Count; i++) + { + var endpoint = endpoints[i]; + var hosts = endpoint.Metadata.GetMetadata()?.Hosts.Select(h => CreateEdgeKey(h)).ToArray(); + if (hosts == null || hosts.Length == 0) + { + hosts = new[] { EdgeKey.WildcardEdgeKey }; + } + + for (var j = 0; j < hosts.Length; j++) + { + var host = hosts[j]; + if (!edges.ContainsKey(host)) + { + edges.Add(host, new List()); + } + } + } + + // Now in a second loop, add endpoints to these lists. We've enumerated all of + // the states, so we want to see which states this endpoint matches. + for (var i = 0; i < endpoints.Count; i++) + { + var endpoint = endpoints[i]; + + var endpointKeys = endpoint.Metadata.GetMetadata()?.Hosts.Select(h => CreateEdgeKey(h)).ToArray() ?? Array.Empty(); + if (endpointKeys.Length == 0) + { + // OK this means that this endpoint matches *all* hosts. + // So, loop and add it to all states. + foreach (var kvp in edges) + { + kvp.Value.Add(endpoint); + } + } + else + { + // OK this endpoint matches specific hosts + foreach (var kvp in edges) + { + // The edgeKey maps to a possible request header value + var edgeKey = kvp.Key; + + for (var j = 0; j < endpointKeys.Length; j++) + { + var endpointKey = endpointKeys[j]; + + if (edgeKey.Equals(endpointKey)) + { + kvp.Value.Add(endpoint); + break; + } + else if (edgeKey.HasHostWildcard && endpointKey.HasHostWildcard && + edgeKey.Port == endpointKey.Port && edgeKey.MatchHost(endpointKey.Host)) + { + kvp.Value.Add(endpoint); + break; + } + } + } + } + } + + return edges + .Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value)) + .ToArray(); + } + + public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList edges) + { + if (edges == null) + { + throw new ArgumentNullException(nameof(edges)); + } + + // Since our 'edges' can have wildcards, we do a sort based on how wildcard-ey they + // are then then execute them in linear order. + var ordered = edges + .Select(e => (host: (EdgeKey)e.State, destination: e.Destination)) + .OrderBy(e => GetScore(e.host)) + .ToArray(); + + return new HostPolicyJumpTable(exitDestination, ordered); + } + + private int GetScore(in EdgeKey key) + { + // Higher score == lower priority. + if (key.MatchesHost && !key.HasHostWildcard && key.MatchesPort) + { + return 1; // Has host AND port, e.g. www.consoto.com:8080 + } + else if (key.MatchesHost && !key.HasHostWildcard) + { + return 2; // Has host, e.g. www.consoto.com + } + else if (key.MatchesHost && key.MatchesPort) + { + return 3; // Has wildcard host AND port, e.g. *.consoto.com:8080 + } + else if (key.MatchesHost) + { + return 4; // Has wildcard host, e.g. *.consoto.com + } + else if (key.MatchesPort) + { + return 5; // Has port, e.g. *:8080 + } + else + { + return 6; // Has neither, e.g. *:* (or no metadata) + } + } + + private class HostMetadataEndpointComparer : EndpointMetadataComparer + { + protected override int CompareMetadata(IHostMetadata x, IHostMetadata y) + { + // Ignore the metadata if it has an empty list of hosts. + return base.CompareMetadata( + x?.Hosts.Count > 0 ? x : null, + y?.Hosts.Count > 0 ? y : null); + } + } + + private class HostPolicyJumpTable : PolicyJumpTable + { + private (EdgeKey host, int destination)[] _destinations; + private int _exitDestination; + + public HostPolicyJumpTable(int exitDestination, (EdgeKey host, int destination)[] destinations) + { + _exitDestination = exitDestination; + _destinations = destinations; + } + + public override int GetDestination(HttpContext httpContext) + { + // HostString can allocate when accessing the host or port + // Store host and port locally and reuse + var requestHost = httpContext.Request.Host; + var host = requestHost.Host; + var port = ResolvePort(httpContext, requestHost); + + var destinations = _destinations; + for (var i = 0; i < destinations.Length; i++) + { + var destination = destinations[i]; + + if ((!destination.host.MatchesPort || destination.host.Port == port) && + destination.host.MatchHost(host)) + { + return destination.destination; + } + } + + return _exitDestination; + } + + private static int? ResolvePort(HttpContext httpContext, HostString requestHost) + { + if (requestHost.Port != null) + { + return requestHost.Port; + } + else if (string.Equals("https", httpContext.Request.Scheme, StringComparison.OrdinalIgnoreCase)) + { + return 443; + } + else if (string.Equals("http", httpContext.Request.Scheme, StringComparison.OrdinalIgnoreCase)) + { + return 80; + } + else + { + return null; + } + } + } + + private readonly struct EdgeKey : IEquatable, IComparable, IComparable + { + private const string WildcardHost = "*"; + internal static readonly EdgeKey WildcardEdgeKey = new EdgeKey(null, null); + + public readonly int? Port; + public readonly string Host; + + private readonly string _wildcardEndsWith; + + public EdgeKey(string host, int? port) + { + Host = host ?? WildcardHost; + Port = port; + + HasHostWildcard = Host.StartsWith("*.", StringComparison.Ordinal); + _wildcardEndsWith = HasHostWildcard ? Host.Substring(1) : null; + } + + public bool HasHostWildcard { get; } + + public bool MatchesHost => !string.Equals(Host, WildcardHost, StringComparison.Ordinal); + + public bool MatchesPort => Port != null; + + public bool MatchesAll => !MatchesHost && !MatchesPort; + + public int CompareTo(EdgeKey other) + { + var result = Comparer.Default.Compare(Host, other.Host); + if (result != 0) + { + return result; + } + + return Comparer.Default.Compare(Port, other.Port); + } + + public int CompareTo(object obj) + { + return CompareTo((EdgeKey)obj); + } + + public bool Equals(EdgeKey other) + { + return string.Equals(Host, other.Host, StringComparison.Ordinal) && Port == other.Port; + } + + public bool MatchHost(string host) + { + if (MatchesHost) + { + if (HasHostWildcard) + { + return host.EndsWith(_wildcardEndsWith, StringComparison.OrdinalIgnoreCase); + } + else + { + return string.Equals(host, Host, StringComparison.OrdinalIgnoreCase); + } + } + + return true; + } + + public override int GetHashCode() + { + return (Host?.GetHashCode() ?? 0) ^ (Port?.GetHashCode() ?? 0); + } + + public override bool Equals(object obj) + { + if (obj is EdgeKey key) + { + return Equals(key); + } + + return false; + } + + public override string ToString() + { + return $"{Host}:{Port?.ToString() ?? "*"}"; + } + } + } +} diff --git a/src/Http/Routing/test/FunctionalTests/HostMatchingTests.cs b/src/Http/Routing/test/FunctionalTests/HostMatchingTests.cs new file mode 100644 index 000000000000..841c5b74c1e1 --- /dev/null +++ b/src/Http/Routing/test/FunctionalTests/HostMatchingTests.cs @@ -0,0 +1,119 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using RoutingWebSite; +using Xunit; + +namespace Microsoft.AspNetCore.Routing.FunctionalTests +{ + public class HostMatchingTests : IClassFixture> + { + private readonly RoutingTestFixture _fixture; + + public HostMatchingTests(RoutingTestFixture fixture) + { + _fixture = fixture; + } + + private HttpClient CreateClient(string baseAddress) + { + var client = _fixture.CreateClient(baseAddress); + + return client; + } + + [Theory] + [InlineData("http://localhost")] + [InlineData("http://localhost:5001")] + public async Task Get_CatchAll(string baseAddress) + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard"); + + // Act + var client = CreateClient(baseAddress); + var response = await client.SendAsync(request); + var responseContent = await response.Content.ReadAsStringAsync(); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("*:*", responseContent); + } + + [Theory] + [InlineData("http://9000.0.0.1")] + [InlineData("http://9000.0.0.1:8888")] + public async Task Get_MatchWildcardDomain(string baseAddress) + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard"); + + // Act + var client = CreateClient(baseAddress); + var response = await client.SendAsync(request); + var responseContent = await response.Content.ReadAsStringAsync(); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("*.0.0.1:*", responseContent); + } + + [Theory] + [InlineData("http://127.0.0.1")] + [InlineData("http://127.0.0.1:8888")] + public async Task Get_MatchDomain(string baseAddress) + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard"); + + // Act + var client = CreateClient(baseAddress); + var response = await client.SendAsync(request); + var responseContent = await response.Content.ReadAsStringAsync(); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("127.0.0.1:*", responseContent); + } + + [Theory] + [InlineData("http://9000.0.0.1:5000")] + [InlineData("http://9000.0.0.1:5001")] + public async Task Get_MatchWildcardDomainAndPort(string baseAddress) + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard"); + + // Act + var client = CreateClient(baseAddress); + var response = await client.SendAsync(request); + var responseContent = await response.Content.ReadAsStringAsync(); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("*.0.0.1:5000,*.0.0.1:5001", responseContent); + } + + [Theory] + [InlineData("http://www.contoso.com")] + [InlineData("http://contoso.com")] + public async Task Get_MatchWildcardDomainAndSubdomain(string baseAddress) + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard"); + + // Act + var client = CreateClient(baseAddress); + var response = await client.SendAsync(request); + var responseContent = await response.Content.ReadAsStringAsync(); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("contoso.com:*,*.contoso.com:*", responseContent); + } + } +} diff --git a/src/Http/Routing/test/FunctionalTests/Microsoft.AspNetCore.Routing.FunctionalTests.csproj b/src/Http/Routing/test/FunctionalTests/Microsoft.AspNetCore.Routing.FunctionalTests.csproj index badadf1fe694..c281194c0cf5 100644 --- a/src/Http/Routing/test/FunctionalTests/Microsoft.AspNetCore.Routing.FunctionalTests.csproj +++ b/src/Http/Routing/test/FunctionalTests/Microsoft.AspNetCore.Routing.FunctionalTests.csproj @@ -1,4 +1,4 @@ - + netcoreapp3.0 diff --git a/src/Http/Routing/test/FunctionalTests/RoutingTestFixture.cs b/src/Http/Routing/test/FunctionalTests/RoutingTestFixture.cs index 1ced141956a2..51c9bc8c5ac7 100644 --- a/src/Http/Routing/test/FunctionalTests/RoutingTestFixture.cs +++ b/src/Http/Routing/test/FunctionalTests/RoutingTestFixture.cs @@ -25,6 +25,14 @@ public RoutingTestFixture() public HttpClient Client { get; } + public HttpClient CreateClient(string baseAddress) + { + var client = _server.CreateClient(); + client.BaseAddress = new Uri(baseAddress); + + return client; + } + public void Dispose() { Client.Dispose(); diff --git a/src/Http/Routing/test/UnitTests/Matching/HostMatcherPolicyIntegrationTest.cs b/src/Http/Routing/test/UnitTests/Matching/HostMatcherPolicyIntegrationTest.cs new file mode 100644 index 000000000000..b369d2a0e5e8 --- /dev/null +++ b/src/Http/Routing/test/UnitTests/Matching/HostMatcherPolicyIntegrationTest.cs @@ -0,0 +1,339 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Routing.Patterns; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Routing.Matching +{ + // End-to-end tests for the host matching functionality + public class HostMatcherPolicyIntegrationTest + { + [Fact] + public async Task Match_Host() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_HostWithPort() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:8080", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com:8080"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_Host_Unicode() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "æon.contoso.com", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "æon.contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_HostWithPort_IncorrectPort() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:8080", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com:1111"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertNotMatch(context, httpContext); + } + + [Fact] + public async Task Match_HostWithPort_IncorrectHost() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:8080", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "www.contoso.com:8080"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertNotMatch(context, httpContext); + } + + [Fact] + public async Task Match_HostWithWildcard() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*.contoso.com:8080", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "æon.contoso.com:8080"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_HostWithWildcard_Unicode() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*.contoso.com:8080", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "www.contoso.com:8080"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_Host_CaseInsensitive() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "Contoso.COM", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_HostWithPort_InferHttpPort() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:80", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com", "http"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_HostWithPort_InferHttpsPort() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:443", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com", "https"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_HostWithPort_NoHostHeader() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:443", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", null, "https"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertNotMatch(context, httpContext); + } + + [Fact] + public async Task Match_Port_NoHostHeader_InferHttpsPort() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*:443", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", null, "https"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_NoMetadata_MatchesAnyHost() + { + // Arrange + var endpoint = CreateEndpoint("/hello"); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_EmptyHostList_MatchesAnyHost() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_WildcardHost_MatchesAnyHost() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + [Fact] + public async Task Match_WildcardHostAndWildcardPort_MatchesAnyHost() + { + // Arrange + var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*:*", }); + + var matcher = CreateMatcher(endpoint); + var (httpContext, context) = CreateContext("/hello", "contoso.com"); + + // Act + await matcher.MatchAsync(httpContext, context); + + // Assert + MatcherAssert.AssertMatch(context, httpContext, endpoint); + } + + private static Matcher CreateMatcher(params RouteEndpoint[] endpoints) + { + var services = new ServiceCollection() + .AddOptions() + .AddLogging() + .AddRouting() + .BuildServiceProvider(); + + var builder = services.GetRequiredService(); + for (var i = 0; i < endpoints.Length; i++) + { + builder.AddEndpoint(endpoints[i]); + } + + return builder.Build(); + } + + internal static (HttpContext httpContext, EndpointSelectorContext context) CreateContext( + string path, + string host, + string scheme = null) + { + var httpContext = new DefaultHttpContext(); + if (host != null) + { + httpContext.Request.Host = new HostString(host); + } + httpContext.Request.Path = path; + httpContext.Request.Scheme = scheme; + + var context = new EndpointSelectorContext(); + httpContext.Features.Set(context); + httpContext.Features.Set(context); + + return (httpContext, context); + } + + internal static RouteEndpoint CreateEndpoint( + string template, + object defaults = null, + object constraints = null, + int order = 0, + string[] hosts = null) + { + var metadata = new List(); + if (hosts != null) + { + metadata.Add(new HostAttribute(hosts ?? Array.Empty())); + } + + var displayName = "endpoint: " + template + " " + string.Join(", ", hosts ?? new[] { "*:*" }); + return new RouteEndpoint( + TestConstants.EmptyRequestDelegate, + RoutePatternFactory.Parse(template, defaults, constraints), + order, + new EndpointMetadataCollection(metadata), + displayName); + } + + internal (Matcher matcher, RouteEndpoint endpoint) CreateMatcher(string template) + { + var endpoint = CreateEndpoint(template); + return (CreateMatcher(endpoint), endpoint); + } + } +} diff --git a/src/Http/Routing/test/UnitTests/Matching/HostMatcherPolicyTest.cs b/src/Http/Routing/test/UnitTests/Matching/HostMatcherPolicyTest.cs new file mode 100644 index 000000000000..9b3421e90834 --- /dev/null +++ b/src/Http/Routing/test/UnitTests/Matching/HostMatcherPolicyTest.cs @@ -0,0 +1,176 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Matching; +using Microsoft.AspNetCore.Routing.Patterns; +using Xunit; + +namespace Microsoft.AspNetCore.Routing.Matching +{ + public class HostMatcherPolicyTest + { + [Fact] + public void AppliesToEndpoints_EndpointWithoutMetadata_ReturnsFalse() + { + // Arrange + var endpoints = new[] { CreateEndpoint("/", null), }; + + var policy = CreatePolicy(); + + // Act + var result = policy.AppliesToEndpoints(endpoints); + + // Assert + Assert.False(result); + } + + [Fact] + public void AppliesToEndpoints_EndpointWithoutHosts_ReturnsFalse() + { + // Arrange + var endpoints = new[] + { + CreateEndpoint("/", new HostAttribute(Array.Empty())), + }; + + var policy = CreatePolicy(); + + // Act + var result = policy.AppliesToEndpoints(endpoints); + + // Assert + Assert.False(result); + } + + [Fact] + public void AppliesToEndpoints_EndpointHasHosts_ReturnsTrue() + { + // Arrange + var endpoints = new[] + { + CreateEndpoint("/", new HostAttribute(Array.Empty())), + CreateEndpoint("/", new HostAttribute(new[] { "localhost", })), + }; + + var policy = CreatePolicy(); + + // Act + var result = policy.AppliesToEndpoints(endpoints); + + // Assert + Assert.True(result); + } + + [Theory] + [InlineData(":")] + [InlineData(":80")] + [InlineData("80:")] + [InlineData("")] + [InlineData("::")] + [InlineData("*:test")] + public void AppliesToEndpoints_InvalidHosts(string host) + { + // Arrange + var endpoints = new[] { CreateEndpoint("/", new HostAttribute(new[] { host })), }; + + var policy = CreatePolicy(); + + // Act & Assert + Assert.Throws(() => + { + policy.AppliesToEndpoints(endpoints); + }); + } + + [Fact] + public void GetEdges_GroupsByHost() + { + // Arrange + var endpoints = new[] + { + CreateEndpoint("/", new HostAttribute(new[] { "*:5000", "*:5001", })), + CreateEndpoint("/", new HostAttribute(Array.Empty())), + CreateEndpoint("/", hostMetadata: null), + CreateEndpoint("/", new HostAttribute("*.contoso.com:*")), + CreateEndpoint("/", new HostAttribute("*.sub.contoso.com:*")), + CreateEndpoint("/", new HostAttribute("www.contoso.com:*")), + CreateEndpoint("/", new HostAttribute("www.contoso.com:5000")), + CreateEndpoint("/", new HostAttribute("*:*")), + }; + + var policy = CreatePolicy(); + + // Act + var edges = policy.GetEdges(endpoints); + + var data = edges.OrderBy(e => e.State).ToList(); + + // Assert + Assert.Collection( + data, + e => + { + Assert.Equal("*:*", e.State.ToString()); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[7], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal("*:5000", e.State.ToString()); + Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[2], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal("*:5001", e.State.ToString()); + Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[2], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal("*.contoso.com:*", e.State.ToString()); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal("*.sub.contoso.com:*", e.State.ToString()); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal("www.contoso.com:*", e.State.ToString()); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[5], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal("www.contoso.com:5000", e.State.ToString()); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[6], }, e.Endpoints.ToArray()); + }); + } + + private static RouteEndpoint CreateEndpoint(string template, IHostMetadata hostMetadata) + { + var metadata = new List(); + if (hostMetadata != null) + { + metadata.Add(hostMetadata); + } + + return new RouteEndpoint( + (context) => Task.CompletedTask, + RoutePatternFactory.Parse(template), + 0, + new EndpointMetadataCollection(metadata), + $"test: {template} - {string.Join(", ", hostMetadata?.Hosts ?? Array.Empty())}"); + } + + private static HostMatcherPolicy CreatePolicy() + { + return new HostMatcherPolicy(); + } + } +} diff --git a/src/Http/Routing/test/UnitTests/Matching/HttpMethodMatcherPolicyIntegrationTest.cs b/src/Http/Routing/test/UnitTests/Matching/HttpMethodMatcherPolicyIntegrationTest.cs index fd223d9b7e49..1bdbcd0e2f74 100644 --- a/src/Http/Routing/test/UnitTests/Matching/HttpMethodMatcherPolicyIntegrationTest.cs +++ b/src/Http/Routing/test/UnitTests/Matching/HttpMethodMatcherPolicyIntegrationTest.cs @@ -356,6 +356,7 @@ internal static (HttpContext httpContext, EndpointSelectorContext context) Creat return (httpContext, context); } + internal static RouteEndpoint CreateEndpoint( string template, object defaults = null, diff --git a/src/Http/Routing/test/UnitTests/RoutingEndpointConventionBuilderExtensionsTests.cs b/src/Http/Routing/test/UnitTests/RoutingEndpointConventionBuilderExtensionsTests.cs new file mode 100644 index 000000000000..de02bb47de08 --- /dev/null +++ b/src/Http/Routing/test/UnitTests/RoutingEndpointConventionBuilderExtensionsTests.cs @@ -0,0 +1,47 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Routing.Patterns; +using Xunit; + +namespace Microsoft.AspNetCore.Routing +{ + public class RoutingEndpointConventionBuilderExtensionsTests + { + [Fact] + public void RequireHost_HostNames() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + + // Act + builder.RequireHost("contoso.com:8080"); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + convention(endpointModel); + + var hostMetadata = Assert.IsType(Assert.Single(endpointModel.Metadata)); + + Assert.Equal("contoso.com:8080", hostMetadata.Hosts.Single()); + } + + private class TestEndpointConventionBuilder : IEndpointConventionBuilder + { + public IList> Conventions { get; } = new List>(); + + public void Add(Action convention) + { + Conventions.Add(convention); + } + } + } +} diff --git a/src/Http/Routing/test/testassets/RoutingWebSite/UseEndpointRoutingStartup.cs b/src/Http/Routing/test/testassets/RoutingWebSite/UseEndpointRoutingStartup.cs index c2f05782674c..765876a2a3be 100644 --- a/src/Http/Routing/test/testassets/RoutingWebSite/UseEndpointRoutingStartup.cs +++ b/src/Http/Routing/test/testassets/RoutingWebSite/UseEndpointRoutingStartup.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Endpoints; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Routing; @@ -112,6 +113,12 @@ public void Configure(IApplicationBuilder app) "Link: " + linkGenerator.GetPathByRouteValues(httpContext, "WithDoubleAsteriskCatchAll", new { })); }, new RouteNameMetadata(routeName: "WithDoubleAsteriskCatchAll")); + + MapHostEndpoint(routes); + MapHostEndpoint(routes, "*.0.0.1"); + MapHostEndpoint(routes, "127.0.0.1"); + MapHostEndpoint(routes, "*.0.0.1:5000", "*.0.0.1:5001"); + MapHostEndpoint(routes, "contoso.com:*", "*.contoso.com:*"); }); app.Map("/Branch1", branch => SetupBranch(branch, "Branch1")); @@ -124,6 +131,31 @@ public void Configure(IApplicationBuilder app) app.UseEndpoint(); } + private IEndpointConventionBuilder MapHostEndpoint(IEndpointRouteBuilder routes, params string[] hosts) + { + var hostsDisplay = (hosts == null || hosts.Length == 0) + ? "*:*" + : string.Join(",", hosts.Select(h => h.Contains(':') ? h : h + ":*")); + + var conventionBuilder = routes.MapGet( + "api/DomainWildcard", + httpContext => + { + var response = httpContext.Response; + response.StatusCode = 200; + response.ContentType = "text/plain"; + return response.WriteAsync(hostsDisplay); + }); + + conventionBuilder.Add(endpointBuilder => + { + endpointBuilder.Metadata.Add(new HostAttribute(hosts)); + endpointBuilder.DisplayName += " HOST: " + hostsDisplay; + }); + + return conventionBuilder; + } + private void SetupBranch(IApplicationBuilder app, string name) { app.UseRouting(routes =>