From 523e5f398ab0a255e5e8d8bb086a61b3e1ccb739 Mon Sep 17 00:00:00 2001 From: Rafiki Assumani <87031580+rafikiassumaniMSFT@users.noreply.github.com> Date: Fri, 20 Aug 2021 19:25:11 -0500 Subject: [PATCH] Improve Minimal APIs support for request media types #35082 (#35230) * add support for request media types * change namespace for acceptsmatcher policy * additional changes * enable 415 when unsupported content type is provide * add accepts extension method on minimalActions endpoint * add IAcceptsMetadata to API description * add empty content type test * feat: add types for iacceptmetadata * change requestdelegate factory to return metatdata * clean RequestDelegateFactoryOptions.cs * change request delegate to return requestdelegateresult type * make apis property init only * adding constructor to requestdelegatefactoryResult * Fixups * fix merge errors * address pr comment * fix test error * remove options from params * implements iacceptsMetadata * fix test failures * fix test failures * move iacceptmetadata to shared source * add acceptsmetadata shared code to mvc * fix tests * address pr comments * address another comment * nit * fix duplicate media types * fix test failures Co-authored-by: Pranav K --- .../src/Metadata/IAcceptsMetadata.cs | 25 + .../src/PublicAPI.Unshipped.txt | 7 + .../src/RequestDelegateResult.cs | 33 + ...icrosoft.AspNetCore.Http.Extensions.csproj | 3 +- .../src/PublicAPI.Unshipped.txt | 4 +- .../src/RequestDelegateFactory.cs | 49 +- .../test/RequestDelegateFactoryTests.cs | 190 ++++-- ...malActionEndpointRouteBuilderExtensions.cs | 11 +- .../RoutingServiceCollectionExtensions.cs | 1 + .../src/Matching/AcceptsMatcherPolicy.cs | 392 +++++++++++ .../src/Microsoft.AspNetCore.Routing.csproj | 14 +- .../Matching/AcceptsMatcherPolicyTest.cs} | 72 +- .../Microsoft.AspNetCore.Routing.Tests.csproj | 6 +- .../src/DefaultApiDescriptionProvider.cs | 1 + ...pointMetadataApiDescriptionProviderTest.cs | 33 + .../IApiRequestMetadataProvider.cs | 3 +- ...nApiEndpointConventionBuilderExtensions.cs | 38 +- src/Mvc/Mvc.Core/src/ConsumesAttribute.cs | 52 +- .../MvcCoreServiceCollectionExtensions.cs | 2 +- .../src/Formatters/AcceptHeaderParser.cs | 1 + .../src/Formatters/HttpParseResult.cs | 12 - .../src/Formatters/HttpTokenParsingRules.cs | 270 -------- src/Mvc/Mvc.Core/src/Formatters/MediaType.cs | 184 +----- .../src/Microsoft.AspNetCore.Mvc.Core.csproj | 3 + src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt | 3 + .../src/Routing/ActionEndpointFactory.cs | 5 +- .../src/Routing/ConsumesMatcherPolicy.cs | 394 ----------- .../Mvc.Core/src/Routing/ConsumesMetadata.cs | 24 - .../Mvc.Core/src/Routing/IConsumesMetadata.cs | 13 - .../MvcCoreServiceCollectionExtensionsTest.cs | 2 +- .../SimpleWithWebApplicationBuilderTests.cs | 43 ++ .../Program.cs | 5 +- src/Shared/MediaType/HttpTokenParsingRule.cs | 277 ++++++++ .../MediaType/ReadOnlyMediaTypeHeaderValue.cs | 625 ++++++++++++++++++ src/Shared/RoutingMetadata/AcceptsMetadata.cs | 54 ++ 35 files changed, 1839 insertions(+), 1012 deletions(-) create mode 100644 src/Http/Http.Abstractions/src/Metadata/IAcceptsMetadata.cs create mode 100644 src/Http/Http.Abstractions/src/RequestDelegateResult.cs create mode 100644 src/Http/Routing/src/Matching/AcceptsMatcherPolicy.cs rename src/{Mvc/Mvc.Core/test/Routing/ConsumesMatcherPolicyTest.cs => Http/Routing/test/UnitTests/Matching/AcceptsMatcherPolicyTest.cs} (85%) delete mode 100644 src/Mvc/Mvc.Core/src/Formatters/HttpParseResult.cs delete mode 100644 src/Mvc/Mvc.Core/src/Formatters/HttpTokenParsingRules.cs delete mode 100644 src/Mvc/Mvc.Core/src/Routing/ConsumesMatcherPolicy.cs delete mode 100644 src/Mvc/Mvc.Core/src/Routing/ConsumesMetadata.cs delete mode 100644 src/Mvc/Mvc.Core/src/Routing/IConsumesMetadata.cs create mode 100644 src/Shared/MediaType/HttpTokenParsingRule.cs create mode 100644 src/Shared/MediaType/ReadOnlyMediaTypeHeaderValue.cs create mode 100644 src/Shared/RoutingMetadata/AcceptsMetadata.cs diff --git a/src/Http/Http.Abstractions/src/Metadata/IAcceptsMetadata.cs b/src/Http/Http.Abstractions/src/Metadata/IAcceptsMetadata.cs new file mode 100644 index 000000000000..a3325158a7ac --- /dev/null +++ b/src/Http/Http.Abstractions/src/Metadata/IAcceptsMetadata.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Http.Metadata +{ + /// + /// Interface for accepting request media types. + /// + public interface IAcceptsMetadata + { + /// + /// Gets a list of the allowed request content types. + /// If the incoming request does not have a Content-Type with one of these values, the request will be rejected with a 415 response. + /// + IReadOnlyList ContentTypes { get; } + + /// + /// Gets the type being read from the request. + /// + Type? RequestType { get; } + } +} diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index 9338ae940214..d689bce14e4b 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -7,6 +7,9 @@ *REMOVED*abstract Microsoft.AspNetCore.Http.HttpRequest.ContentType.get -> string! Microsoft.AspNetCore.Http.IResult Microsoft.AspNetCore.Http.IResult.ExecuteAsync(Microsoft.AspNetCore.Http.HttpContext! httpContext) -> System.Threading.Tasks.Task! +Microsoft.AspNetCore.Http.Metadata.IAcceptsMetadata +Microsoft.AspNetCore.Http.Metadata.IAcceptsMetadata.ContentTypes.get -> System.Collections.Generic.IReadOnlyList! +Microsoft.AspNetCore.Http.Metadata.IAcceptsMetadata.RequestType.get -> System.Type? Microsoft.AspNetCore.Http.Metadata.IFromBodyMetadata Microsoft.AspNetCore.Http.Metadata.IFromBodyMetadata.AllowEmpty.get -> bool Microsoft.AspNetCore.Http.Metadata.IFromHeaderMetadata @@ -18,6 +21,10 @@ Microsoft.AspNetCore.Http.Metadata.IFromRouteMetadata.Name.get -> string? Microsoft.AspNetCore.Http.Metadata.IFromServiceMetadata Microsoft.AspNetCore.Http.Endpoint.Endpoint(Microsoft.AspNetCore.Http.RequestDelegate? requestDelegate, Microsoft.AspNetCore.Http.EndpointMetadataCollection? metadata, string? displayName) -> void Microsoft.AspNetCore.Http.Endpoint.RequestDelegate.get -> Microsoft.AspNetCore.Http.RequestDelegate? +Microsoft.AspNetCore.Http.RequestDelegateResult +Microsoft.AspNetCore.Http.RequestDelegateResult.EndpointMetadata.get -> System.Collections.Generic.IReadOnlyList! +Microsoft.AspNetCore.Http.RequestDelegateResult.RequestDelegate.get -> Microsoft.AspNetCore.Http.RequestDelegate! +Microsoft.AspNetCore.Http.RequestDelegateResult.RequestDelegateResult(Microsoft.AspNetCore.Http.RequestDelegate! requestDelegate, System.Collections.Generic.IReadOnlyList! metadata) -> void Microsoft.AspNetCore.Routing.RouteValueDictionary.TryAdd(string! key, object? value) -> bool static readonly Microsoft.AspNetCore.Http.HttpProtocol.Http09 -> string! static Microsoft.AspNetCore.Http.HttpProtocol.IsHttp09(string! protocol) -> bool diff --git a/src/Http/Http.Abstractions/src/RequestDelegateResult.cs b/src/Http/Http.Abstractions/src/RequestDelegateResult.cs new file mode 100644 index 000000000000..88ddd28a4173 --- /dev/null +++ b/src/Http/Http.Abstractions/src/RequestDelegateResult.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Http +{ + /// + /// The result of creating a from a + /// + public sealed class RequestDelegateResult + { + /// + /// Creates a new instance of . + /// + public RequestDelegateResult(RequestDelegate requestDelegate, IReadOnlyList metadata) + { + RequestDelegate = requestDelegate; + EndpointMetadata = metadata; + } + + /// + /// Gets the + /// + public RequestDelegate RequestDelegate { get;} + + /// + /// Gets endpoint metadata inferred from creating the + /// + public IReadOnlyList EndpointMetadata { get;} + } + +} diff --git a/src/Http/Http.Extensions/src/Microsoft.AspNetCore.Http.Extensions.csproj b/src/Http/Http.Extensions/src/Microsoft.AspNetCore.Http.Extensions.csproj index 7e4e66fe03aa..93a83901bac1 100644 --- a/src/Http/Http.Extensions/src/Microsoft.AspNetCore.Http.Extensions.csproj +++ b/src/Http/Http.Extensions/src/Microsoft.AspNetCore.Http.Extensions.csproj @@ -1,4 +1,4 @@ - + ASP.NET Core common extension methods for HTTP abstractions, HTTP headers, HTTP request/response, and session state. @@ -16,6 +16,7 @@ + diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index aac4480a487c..65dc3997cc20 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -192,8 +192,8 @@ static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.AppendList(th static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.GetTypedHeaders(this Microsoft.AspNetCore.Http.HttpRequest! request) -> Microsoft.AspNetCore.Http.Headers.RequestHeaders! static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.GetTypedHeaders(this Microsoft.AspNetCore.Http.HttpResponse! response) -> Microsoft.AspNetCore.Http.Headers.ResponseHeaders! static Microsoft.AspNetCore.Http.HttpContextServerVariableExtensions.GetServerVariable(this Microsoft.AspNetCore.Http.HttpContext! context, string! variableName) -> string? -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! action, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegate! -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func? targetFactory = null, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegate! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! action, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func? targetFactory = null, Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions? options = null) -> Microsoft.AspNetCore.Http.RequestDelegateResult! static Microsoft.AspNetCore.Http.ResponseExtensions.Clear(this Microsoft.AspNetCore.Http.HttpResponse! response) -> void static Microsoft.AspNetCore.Http.ResponseExtensions.Redirect(this Microsoft.AspNetCore.Http.HttpResponse! response, string! location, bool permanent, bool preserveMethod) -> void static Microsoft.AspNetCore.Http.SendFileResponseExtensions.SendFileAsync(this Microsoft.AspNetCore.Http.HttpResponse! response, Microsoft.Extensions.FileProviders.IFileInfo! file, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index a735b9abee20..eb6a6eebe06f 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Linq; using System.Linq.Expressions; +using System.Net.Http; using System.Reflection; using System.Security.Claims; using System.Text; @@ -63,14 +64,16 @@ public static partial class RequestDelegateFactory private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); + private static readonly AcceptsMetadata DefaultAcceptsMetadata = new(new[] { "application/json" }); + /// /// Creates a implementation for . /// /// A request handler with any number of custom parameters that often produces a response with its return value. /// The used to configure the behavior of the handler. - /// The . + /// The . #pragma warning disable RS0026 // Do not add multiple public overloads with optional parameters - public static RequestDelegate Create(Delegate action, RequestDelegateFactoryOptions? options = null) + public static RequestDelegateResult Create(Delegate action, RequestDelegateFactoryOptions? options = null) #pragma warning restore RS0026 // Do not add multiple public overloads with optional parameters { if (action is null) @@ -84,12 +87,15 @@ public static RequestDelegate Create(Delegate action, RequestDelegateFactoryOpti null => null, }; - var targetableRequestDelegate = CreateTargetableRequestDelegate(action.Method, options, targetExpression); - - return httpContext => + var factoryContext = new FactoryContext { - return targetableRequestDelegate(action.Target, httpContext); + ServiceProviderIsService = options?.ServiceProvider?.GetService() }; + + var targetableRequestDelegate = CreateTargetableRequestDelegate(action.Method, options, factoryContext, targetExpression); + + return new RequestDelegateResult(httpContext => targetableRequestDelegate(action.Target, httpContext), factoryContext.Metadata); + } /// @@ -100,7 +106,7 @@ public static RequestDelegate Create(Delegate action, RequestDelegateFactoryOpti /// The used to configure the behavior of the handler. /// The . #pragma warning disable RS0026 // Do not add multiple public overloads with optional parameters - public static RequestDelegate Create(MethodInfo methodInfo, Func? targetFactory = null, RequestDelegateFactoryOptions? options = null) + public static RequestDelegateResult Create(MethodInfo methodInfo, Func? targetFactory = null, RequestDelegateFactoryOptions? options = null) #pragma warning restore RS0026 // Do not add multiple public overloads with optional parameters { if (methodInfo is null) @@ -113,31 +119,30 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func() + }; + if (targetFactory is null) { if (methodInfo.IsStatic) { - var untargetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, options, targetExpression: null); + var untargetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, options, factoryContext, targetExpression: null); - return httpContext => - { - return untargetableRequestDelegate(null, httpContext); - }; + return new RequestDelegateResult(httpContext => untargetableRequestDelegate(null, httpContext), factoryContext.Metadata); } targetFactory = context => Activator.CreateInstance(methodInfo.DeclaringType)!; } var targetExpression = Expression.Convert(TargetExpr, methodInfo.DeclaringType); - var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, options, targetExpression); + var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, options, factoryContext, targetExpression); - return httpContext => - { - return targetableRequestDelegate(targetFactory(httpContext), httpContext); - }; + return new RequestDelegateResult(httpContext => targetableRequestDelegate(targetFactory(httpContext), httpContext), factoryContext.Metadata); } - private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, RequestDelegateFactoryOptions? options, Expression? targetExpression) + private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, RequestDelegateFactoryOptions? options, FactoryContext factoryContext, Expression? targetExpression) { // Non void return type @@ -155,11 +160,6 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func() - }; - if (options?.RouteParameterNames is { } routeParameterNames) { factoryContext.RouteParameters = new(routeParameterNames); @@ -861,6 +861,7 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al } } + factoryContext.Metadata.Add(DefaultAcceptsMetadata); var isOptional = IsOptionalParameter(parameter); factoryContext.JsonRequestBodyType = parameter.ParameterType; @@ -1111,6 +1112,8 @@ private class FactoryContext public Dictionary TrackedParameters { get; } = new(); public bool HasMultipleBodyParameters { get; set; } + + public List Metadata { get; } = new(); } private static class RequestDelegateFactoryConstants diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index a6236e37c802..2326e6907b09 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -3,6 +3,7 @@ #nullable enable +using System; using System.Globalization; using System.Linq.Expressions; using System.Net; @@ -91,7 +92,8 @@ public async Task RequestDelegateInvokesAction(Delegate @delegate) { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -111,7 +113,8 @@ public async Task StaticMethodInfoOverloadWorksWithBasicReflection() BindingFlags.NonPublic | BindingFlags.Static, new[] { typeof(HttpContext) }); - var requestDelegate = RequestDelegateFactory.Create(methodInfo!); + var factoryResult = RequestDelegateFactory.Create(methodInfo!); + var requestDelegate = factoryResult.RequestDelegate; var httpContext = new DefaultHttpContext(); @@ -156,7 +159,8 @@ object GetTarget() return new TestNonStaticActionClass(2); } - var requestDelegate = RequestDelegateFactory.Create(methodInfo!, _ => GetTarget()); + var factoryResult = RequestDelegateFactory.Create(methodInfo!, _ => GetTarget()); + var requestDelegate = factoryResult.RequestDelegate; var httpContext = new DefaultHttpContext(); @@ -202,7 +206,8 @@ static void TestAction(HttpContext httpContext, [FromRoute] int value) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -229,7 +234,7 @@ public async Task SpecifiedRouteParametersDoNotFallbackToQueryString() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((int? id, HttpContext httpContext) => + var factoryResult = RequestDelegateFactory.Create((int? id, HttpContext httpContext) => { if (id is not null) { @@ -238,6 +243,8 @@ public async Task SpecifiedRouteParametersDoNotFallbackToQueryString() }, new() { RouteParameterNames = new string[] { "id" } }); + var requestDelegate = factoryResult.RequestDelegate; + httpContext.Request.Query = new QueryCollection(new Dictionary { ["id"] = "42" @@ -253,7 +260,7 @@ public async Task SpecifiedQueryParametersDoNotFallbackToRouteValues() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((int? id, HttpContext httpContext) => + var factoryResult = RequestDelegateFactory.Create((int? id, HttpContext httpContext) => { if (id is not null) { @@ -271,6 +278,8 @@ public async Task SpecifiedQueryParametersDoNotFallbackToRouteValues() ["id"] = "42" }; + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); Assert.Equal(41, httpContext.Items["input"]); @@ -281,7 +290,7 @@ public async Task NullRouteParametersPrefersRouteOverQueryString() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((int? id, HttpContext httpContext) => + var factoryResult = RequestDelegateFactory.Create((int? id, HttpContext httpContext) => { if (id is not null) { @@ -299,6 +308,7 @@ public async Task NullRouteParametersPrefersRouteOverQueryString() ["id"] = "42" }; + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); Assert.Equal(42, httpContext.Items["input"]); @@ -311,7 +321,9 @@ public async Task CreatingDelegateWithInstanceMethodInfoCreatesInstancePerCall() Assert.NotNull(methodInfo); - var requestDelegate = RequestDelegateFactory.Create(methodInfo!); + var factoryResult = RequestDelegateFactory.Create(methodInfo!); + var requestDelegate = factoryResult.RequestDelegate; + var context = new DefaultHttpContext(); await requestDelegate(context); @@ -337,7 +349,8 @@ public async Task RequestDelegatePopulatesFromRouteOptionalParameter() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(TestOptional); + var factoryResult = RequestDelegateFactory.Create(TestOptional); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -349,7 +362,8 @@ public async Task RequestDelegatePopulatesFromNullableOptionalParameter() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(TestOptional); + var factoryResult = RequestDelegateFactory.Create(TestOptional); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -361,7 +375,8 @@ public async Task RequestDelegatePopulatesFromOptionalStringParameter() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(TestOptionalString); + var factoryResult = RequestDelegateFactory.Create(TestOptionalString); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -378,7 +393,8 @@ public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParam httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create(TestOptional); + var factoryResult = RequestDelegateFactory.Create(TestOptional); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -401,7 +417,8 @@ void TestAction([FromRoute(Name = specifiedName)] int foo) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[specifiedName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -428,7 +445,8 @@ void TestAction([FromRoute] int foo) serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -557,7 +575,8 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(action); + var factoryResult = RequestDelegateFactory.Create(action); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -578,7 +597,8 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQ serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(action); + var factoryResult = RequestDelegateFactory.Create(action); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -597,11 +617,13 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR ["tryParsable"] = "invalid!" }); - var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, int tryParsable) => + var factoryResult = RequestDelegateFactory.Create((HttpContext httpContext, int tryParsable) => { httpContext.Items["tryParsable"] = tryParsable; }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); Assert.Equal(42, httpContext.Items["tryParsable"]); @@ -614,11 +636,13 @@ public async Task RequestDelegatePrefersBindAsyncOverTryParseString() httpContext.Request.Headers.Referer = "https://example.org"; - var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncRecord tryParsable) => + var resultFactory = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncRecord tryParsable) => { httpContext.Items["tryParsable"] = tryParsable; }); + var requestDelegate = resultFactory.RequestDelegate; + await requestDelegate(httpContext); Assert.Equal(new MyBindAsyncRecord(new Uri("https://example.org")), httpContext.Items["tryParsable"]); @@ -631,11 +655,12 @@ public async Task RequestDelegatePrefersBindAsyncOverTryParseStringForNonNullabl httpContext.Request.Headers.Referer = "https://example.org"; - var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct tryParsable) => + var resultFactory = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct tryParsable) => { httpContext.Items["tryParsable"] = tryParsable; }); + var requestDelegate = resultFactory.RequestDelegate; await requestDelegate(httpContext); Assert.Equal(new MyBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["tryParsable"]); @@ -644,8 +669,9 @@ public async Task RequestDelegatePrefersBindAsyncOverTryParseStringForNonNullabl [Fact] public async Task RequestDelegateUsesTryParseStringoOverBindAsyncGivenExplicitAttribute() { - var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromRoute] MyBindAsyncRecord tryParsable) => { }); - var fromQueryRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromQuery] MyBindAsyncRecord tryParsable) => { }); + var fromRouteFactoryResult = RequestDelegateFactory.Create((HttpContext httpContext, [FromRoute] MyBindAsyncRecord tryParsable) => { }); + var fromQueryFactoryResult = RequestDelegateFactory.Create((HttpContext httpContext, [FromQuery] MyBindAsyncRecord tryParsable) => { }); + var httpContext = new DefaultHttpContext { @@ -662,6 +688,9 @@ public async Task RequestDelegateUsesTryParseStringoOverBindAsyncGivenExplicitAt }, }; + var fromRouteRequestDelegate = fromRouteFactoryResult.RequestDelegate; + var fromQueryRequestDelegate = fromQueryFactoryResult.RequestDelegate; + await Assert.ThrowsAsync(() => fromRouteRequestDelegate(httpContext)); await Assert.ThrowsAsync(() => fromQueryRequestDelegate(httpContext)); } @@ -669,7 +698,7 @@ public async Task RequestDelegateUsesTryParseStringoOverBindAsyncGivenExplicitAt [Fact] public async Task RequestDelegateUsesTryParseStringOverBindAsyncGivenNullableStruct() { - var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct? tryParsable) => { }); + var fromRouteFactoryResult = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct? tryParsable) => { }); var httpContext = new DefaultHttpContext { @@ -682,6 +711,7 @@ public async Task RequestDelegateUsesTryParseStringOverBindAsyncGivenNullableStr }, }; + var fromRouteRequestDelegate = fromRouteFactoryResult.RequestDelegate; await Assert.ThrowsAsync(() => fromRouteRequestDelegate(httpContext)); } @@ -738,7 +768,8 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2) httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -770,11 +801,12 @@ public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response() var invoked = false; - var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncRecord arg1, MyBindAsyncRecord arg2) => + var factoryResult = RequestDelegateFactory.Create((MyBindAsyncRecord arg1, MyBindAsyncRecord arg2) => { invoked = true; }); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); Assert.False(invoked); @@ -803,7 +835,9 @@ public async Task BindAsyncExceptionsThrowException() RequestServices = new ServiceCollection().AddSingleton(LoggerFactory).BuildServiceProvider(), }; - var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncTypeThatThrows arg1) => { }); + var factoryResult = RequestDelegateFactory.Create((MyBindAsyncTypeThatThrows arg1) => { }); + + var requestDelegate = factoryResult.RequestDelegate; var ex = await Assert.ThrowsAsync(() => requestDelegate(httpContext)); Assert.Equal("BindAsync failed", ex.Message); @@ -845,13 +879,15 @@ public async Task BindAsyncWithBodyArgument() var invoked = false; - var requestDelegate = RequestDelegateFactory.Create((HttpContext context, MyBindAsyncRecord arg1, Todo todo) => + var factoryResult = RequestDelegateFactory.Create((HttpContext context, MyBindAsyncRecord arg1, Todo todo) => { invoked = true; context.Items[nameof(arg1)] = arg1; context.Items[nameof(todo)] = todo; }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); Assert.True(invoked); @@ -899,13 +935,15 @@ public async Task BindAsyncRunsBeforeBodyBinding() var invoked = false; - var requestDelegate = RequestDelegateFactory.Create((HttpContext context, CustomTodo customTodo, Todo todo) => + var factoryResult = RequestDelegateFactory.Create((HttpContext context, CustomTodo customTodo, Todo todo) => { invoked = true; context.Items[nameof(customTodo)] = customTodo; context.Items[nameof(todo)] = todo; }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); Assert.True(invoked); @@ -938,7 +976,8 @@ void TestAction([FromQuery] int value) var httpContext = new DefaultHttpContext(); httpContext.Request.Query = query; - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -961,7 +1000,8 @@ void TestAction([FromHeader(Name = customHeaderName)] int value) var httpContext = new DefaultHttpContext(); httpContext.Request.Headers[customHeaderName] = originalHeaderParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1035,7 +1075,8 @@ public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) }); httpContext.RequestServices = mock.Object; - var requestDelegate = RequestDelegateFactory.Create(action); + var factoryResult = RequestDelegateFactory.Create(action); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1057,7 +1098,8 @@ public async Task RequestDelegateRejectsEmptyBodyGivenFromBodyParameter(Delegate serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(action); + var factoryResult = RequestDelegateFactory.Create(action); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1078,7 +1120,8 @@ void TestAction([FromBody(AllowEmpty = true)] Todo todo) httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1102,7 +1145,8 @@ void TestAction([FromBody(AllowEmpty = true)] BodyStruct bodyStruct) httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1131,7 +1175,8 @@ void TestAction([FromBody] Todo todo) httpContext.Features.Set(new RequestBodyDetectionFeature(true)); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1167,7 +1212,8 @@ void TestAction([FromBody] Todo todo) httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1240,7 +1286,8 @@ public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Dele var httpContext = new DefaultHttpContext(); httpContext.RequestServices = new EmptyServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(action); + var factoryResult = RequestDelegateFactory.Create(action); + var requestDelegate = factoryResult.RequestDelegate; await Assert.ThrowsAsync(() => requestDelegate(httpContext)); } @@ -1262,7 +1309,8 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt var httpContext = new DefaultHttpContext(); httpContext.RequestServices = requestScoped.ServiceProvider; - var requestDelegate = RequestDelegateFactory.Create(action, options: new() { ServiceProvider = services }); + var factoryResult = RequestDelegateFactory.Create(action, options: new() { ServiceProvider = services }); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1281,7 +1329,8 @@ void TestAction(HttpContext httpContext) var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1304,7 +1353,8 @@ void TestAction(CancellationToken cancellationToken) RequestAborted = cts.Token }; - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1326,7 +1376,8 @@ void TestAction(ClaimsPrincipal user) User = new ClaimsPrincipal() }; - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1345,7 +1396,8 @@ void TestAction(HttpRequest httpRequest) var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1364,7 +1416,8 @@ void TestAction(HttpResponse httpResponse) var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(TestAction); + var factoryResult = RequestDelegateFactory.Create(TestAction); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1408,7 +1461,8 @@ public async Task RequestDelegateWritesComplexReturnValueAsJsonResponseBody(Dele var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1482,7 +1536,8 @@ public async Task RequestDelegateUsesCustomIResult(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1545,7 +1600,8 @@ public async Task RequestDelegateWritesStringReturnValueAndSetContentTypeWhenNul var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1562,7 +1618,8 @@ public async Task RequestDelegateWritesStringReturnDoNotChangeContentType(Delega var httpContext = new DefaultHttpContext(); httpContext.Response.ContentType = "application/json; charset=utf-8"; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1601,7 +1658,8 @@ public async Task RequestDelegateWritesIntReturnValue(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1642,7 +1700,8 @@ public async Task RequestDelegateWritesBoolReturnValue(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1680,7 +1739,8 @@ public async Task RequestDelegateThrowsInvalidOperationExceptionOnNullDelegate(D var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; var exception = await Assert.ThrowsAnyAsync(async () => await requestDelegate(httpContext)); Assert.Contains(message, exception.Message); @@ -1725,7 +1785,8 @@ public async Task RequestDelegateWritesNullReturnNullValue(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1784,7 +1845,8 @@ public async Task RequestDelegateHandlesQueryParamOptionality(Delegate @delegate serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1855,11 +1917,13 @@ public async Task RequestDelegateHandlesRouteParamOptionality(Delegate @delegate serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(@delegate, new() + var factoryResult = RequestDelegateFactory.Create(@delegate, new() { RouteParameterNames = routeParam is not null ? new[] { paramName } : Array.Empty() }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); var logs = TestSink.Writes.ToArray(); @@ -1929,7 +1993,8 @@ public async Task RequestDelegateHandlesBodyParamOptionality(Delegate @delegate, serviceCollection.AddSingleton(Options.Create(jsonOptions)); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -1962,11 +2027,12 @@ public async Task RequestDelegateDoesSupportBindAsyncOptionality() var invoked = false; - var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncRecord? arg1) => + var factoryResult = RequestDelegateFactory.Create((MyBindAsyncRecord? arg1) => { invoked = true; }); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); Assert.True(invoked); @@ -2012,7 +2078,8 @@ public async Task RequestDelegateHandlesServiceParamOptionality(Delegate @delega httpContext.RequestServices = services; RequestDelegateFactoryOptions options = new() { ServiceProvider = services }; - var requestDelegate = RequestDelegateFactory.Create(@delegate, options); + var factoryResult = RequestDelegateFactory.Create(@delegate, options); + var requestDelegate = factoryResult.RequestDelegate; if (!isInvalid) { @@ -2056,7 +2123,9 @@ public async Task AllowEmptyOverridesOptionality(Delegate @delegate, bool allows serviceCollection.AddSingleton(LoggerFactory); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var factoryResult = RequestDelegateFactory.Create(@delegate); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); @@ -2098,7 +2167,8 @@ public async Task CanSetStringParamAsOptionalWithNullabilityDisability(bool prov }); } - var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + var factoryResult = RequestDelegateFactory.Create(optionalQueryParam); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -2127,7 +2197,8 @@ public async Task CanSetParseableStringParamAsOptionalWithNullabilityDisability( }); } - var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + var factoryResult = RequestDelegateFactory.Create(optionalQueryParam); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -2156,8 +2227,9 @@ public async Task TreatsUnknownNullabilityAsOptionalForReferenceType(bool provid }); } - var requestDelegate = RequestDelegateFactory.Create(optionalQueryParam); + var factoryResult = RequestDelegateFactory.Create(optionalQueryParam); + var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); Assert.Equal(200, httpContext.Response.StatusCode); diff --git a/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs index 15281a28ccb3..d209bd774a57 100644 --- a/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs @@ -7,6 +7,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.CodeAnalysis.CSharp.Symbols; @@ -173,8 +174,10 @@ public static MinimalActionEndpointConventionBuilder Map( RouteParameterNames = routeParams }; + var requestDelegateResult = RequestDelegateFactory.Create(action, options); + var builder = new RouteEndpointBuilder( - RequestDelegateFactory.Create(action, options), + requestDelegateResult.RequestDelegate, pattern, defaultOrder) { @@ -203,6 +206,12 @@ public static MinimalActionEndpointConventionBuilder Map( // Add delegate attributes as metadata var attributes = action.Method.GetCustomAttributes(); + // Add add request delegate metadata + foreach (var metadata in requestDelegateResult.EndpointMetadata) + { + builder.Metadata.Add(metadata); + } + // This can be null if the delegate is a dynamic method or compiled from an expression tree if (attributes is not null) { diff --git a/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs b/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs index 226b85cebe75..2c0944eec4e4 100644 --- a/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs +++ b/src/Http/Routing/src/DependencyInjection/RoutingServiceCollectionExtensions.cs @@ -91,6 +91,7 @@ public static IServiceCollection AddRouting(this IServiceCollection services) services.TryAddSingleton(); services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddEnumerable(ServiceDescriptor.Singleton()); + services.TryAddEnumerable(ServiceDescriptor.Singleton()); // // Misc infrastructure diff --git a/src/Http/Routing/src/Matching/AcceptsMatcherPolicy.cs b/src/Http/Routing/src/Matching/AcceptsMatcherPolicy.cs new file mode 100644 index 000000000000..ee686e17749b --- /dev/null +++ b/src/Http/Routing/src/Matching/AcceptsMatcherPolicy.cs @@ -0,0 +1,392 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Headers; +using Microsoft.AspNetCore.Http.Metadata; + +namespace Microsoft.AspNetCore.Routing.Matching; + +internal sealed class AcceptsMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy, IEndpointSelectorPolicy +{ + internal const string Http415EndpointDisplayName = "415 HTTP Unsupported Media Type"; + internal const string AnyContentType = "*/*"; + + // Run after HTTP methods, but before 'default'. + public override int Order { get; } = -100; + + public IComparer Comparer { get; } = new ConsumesMetadataEndpointComparer(); + + bool INodeBuilderPolicy.AppliesToEndpoints(IReadOnlyList endpoints) + { + if (endpoints == null) + { + throw new ArgumentNullException(nameof(endpoints)); + } + + if (ContainsDynamicEndpoints(endpoints)) + { + return false; + } + + return AppliesToEndpointsCore(endpoints); + } + + bool IEndpointSelectorPolicy.AppliesToEndpoints(IReadOnlyList endpoints) + { + if (endpoints == null) + { + throw new ArgumentNullException(nameof(endpoints)); + } + + // When the node contains dynamic endpoints we can't make any assumptions. + return ContainsDynamicEndpoints(endpoints); + } + + private static bool AppliesToEndpointsCore(IReadOnlyList endpoints) + { + return endpoints.Any(e => e.Metadata.GetMetadata()?.ContentTypes.Count > 0); + } + + public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates) + { + if (httpContext == null) + { + throw new ArgumentNullException(nameof(httpContext)); + } + + if (candidates == null) + { + throw new ArgumentNullException(nameof(candidates)); + } + + // We want to return a 415 if we eliminated ALL of the currently valid endpoints due to content type + // mismatch. + bool? needs415Endpoint = null; + + for (var i = 0; i < candidates.Count; i++) + { + // We do this check first for consistency with how 415 is implemented for the graph version + // of this code. We still want to know if any endpoints in this set require an a ContentType + // even if those endpoints are already invalid - hence the null check. + var metadata = candidates[i].Endpoint?.Metadata.GetMetadata(); + if (metadata == null || metadata.ContentTypes?.Count == 0) + { + // Can match any content type. + needs415Endpoint = false; + continue; + } + + // Saw a valid endpoint. + needs415Endpoint = needs415Endpoint ?? true; + + if (!candidates.IsValidCandidate(i)) + { + // If the candidate is already invalid, then do a search to see if it has a wildcard content type. + // + // We don't want to return a 415 if any content type could be accepted depending on other parameters. + if (metadata != null) + { + for (var j = 0; j < metadata.ContentTypes?.Count; j++) + { + if (string.Equals("*/*", metadata.ContentTypes[j], StringComparison.Ordinal)) + { + needs415Endpoint = false; + break; + } + } + } + + continue; + } + + var contentType = httpContext.Request.ContentType; + var mediaType = string.IsNullOrEmpty(contentType) ? (ReadOnlyMediaTypeHeaderValue?)null : new(contentType); + + var matched = false; + for (var j = 0; j < metadata.ContentTypes?.Count; j++) + { + var candidateMediaType = new ReadOnlyMediaTypeHeaderValue(metadata.ContentTypes[j]); + if (candidateMediaType.MatchesAllTypes) + { + // We don't need a 415 response because there's an endpoint that would accept any type. + needs415Endpoint = false; + } + + // If there's no ContentType, then then can only matched by a wildcard `*/*`. + if (mediaType == null && !candidateMediaType.MatchesAllTypes) + { + continue; + } + + // We have a ContentType but it's not a match. + else if (mediaType != null && !mediaType.Value.IsSubsetOf(candidateMediaType)) + { + continue; + } + + // We have a ContentType and we accept any value OR we have a ContentType and it's a match. + matched = true; + needs415Endpoint = false; + break; + } + + if (!matched) + { + candidates.SetValidity(i, false); + } + } + + if (needs415Endpoint == true) + { + // We saw some endpoints coming in, and we eliminated them all. + httpContext.SetEndpoint(CreateRejectionEndpoint()); + } + + return Task.CompletedTask; + } + + public IReadOnlyList GetEdges(IReadOnlyList endpoints) + { + if (endpoints == null) + { + throw new ArgumentNullException(nameof(endpoints)); + } + + // The algorithm here is designed to be preserve the order of the endpoints + // while also being relatively simple. Preserving order is important. + + // First, build a dictionary of all of the content-type patterns that are included + // at this node. + // + // For now we're just building up the set of keys. We don't add any endpoints + // to lists now because we don't want ordering problems. + var edges = new Dictionary>(StringComparer.OrdinalIgnoreCase); + for (var i = 0; i < endpoints.Count; i++) + { + var endpoint = endpoints[i]; + var contentTypes = endpoint.Metadata.GetMetadata()?.ContentTypes; + if (contentTypes == null || contentTypes.Count == 0) + { + contentTypes = new string[] { AnyContentType, }; + } + + for (var j = 0; j < contentTypes.Count; j++) + { + var contentType = contentTypes[j]; + + if (!edges.ContainsKey(contentType)) + { + edges.Add(contentType, new List()); + } + } + } + + // Now in a second loop, add endpoints to these lists. We've enumerated all of + // the states, so we want to see which states this endpoint matches. + for (var i = 0; i < endpoints.Count; i++) + { + var endpoint = endpoints[i]; + var contentTypes = endpoint.Metadata.GetMetadata()?.ContentTypes ?? Array.Empty(); + if (contentTypes.Count == 0) + { + // OK this means that this endpoint matches *all* content methods. + // So, loop and add it to all states. + foreach (var kvp in edges) + { + kvp.Value.Add(endpoint); + } + } + else + { + // OK this endpoint matches specific content types -- we have to loop through edges here + // because content types could either be exact (like 'application/json') or they + // could have wildcards (like 'text/*'). We don't expect wildcards to be especially common + // with consumes, but we need to support it. + foreach (var kvp in edges) + { + // The edgeKey maps to a possible request header value + var edgeKey = new ReadOnlyMediaTypeHeaderValue(kvp.Key); + + for (var j = 0; j < contentTypes.Count; j++) + { + var contentType = contentTypes[j]; + + var mediaType = new ReadOnlyMediaTypeHeaderValue(contentType); + + // Example: 'application/json' is subset of 'application/*' + // + // This means that when the request has content-type 'application/json' an endpoint + // what consumes 'application/*' should match. + if (edgeKey.IsSubsetOf(mediaType)) + { + kvp.Value.Add(endpoint); + + // It's possible that a ConsumesMetadata defines overlapping wildcards. Don't add an endpoint + // to any edge twice + break; + } + } + } + } + } + + // If after we're done there isn't any endpoint that accepts */*, then we'll synthesize an + // endpoint that always returns a 415. + if (!edges.TryGetValue(AnyContentType, out var anyEndpoints)) + { + edges.Add(AnyContentType, new List() + { + CreateRejectionEndpoint(), + }); + + // Add a node to use when there is no request content type. + // When there is no content type we want the policy to no-op + edges.Add(string.Empty, endpoints.ToList()); + } + else + { + // If there is an endpoint that accepts */* then it is also used when there is no content type + edges.Add(string.Empty, anyEndpoints.ToList()); + } + + + return edges + .Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value)) + .ToArray(); + } + + private Endpoint CreateRejectionEndpoint() + { + return new Endpoint( + (context) => + { + context.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType; + return Task.CompletedTask; + }, + EndpointMetadataCollection.Empty, + Http415EndpointDisplayName); + } + + public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList edges) + { + if (edges == null) + { + throw new ArgumentNullException(nameof(edges)); + } + + // Since our 'edges' can have wildcards, we do a sort based on how wildcard-ey they + // are then then execute them in linear order. + var ordered = edges + .Select(e => (mediaType: CreateEdgeMediaType(ref e), destination: e.Destination)) + .OrderBy(e => GetScore(e.mediaType)) + .ToArray(); + + // If any edge matches all content types, then treat that as the 'exit'. This will + // always happen because we insert a 415 endpoint. + for (var i = 0; i < ordered.Length; i++) + { + if (ordered[i].mediaType.MatchesAllTypes) + { + exitDestination = ordered[i].destination; + break; + } + } + + var noContentTypeDestination = GetNoContentTypeDestination(ordered); + + return new ConsumesPolicyJumpTable(exitDestination, noContentTypeDestination, ordered); + } + + private static int GetNoContentTypeDestination((ReadOnlyMediaTypeHeaderValue mediaType, int destination)[] destinations) + { + for (var i = 0; i < destinations.Length; i++) + { + var mediaType = destinations[i].mediaType; + + if (!mediaType.Type.HasValue) + { + return destinations[i].destination; + } + } + + throw new InvalidOperationException("Could not find destination for no content type."); + } + + private static ReadOnlyMediaTypeHeaderValue CreateEdgeMediaType(ref PolicyJumpTableEdge e) + { + var mediaType = (string)e.State; + return !string.IsNullOrEmpty(mediaType) ? new ReadOnlyMediaTypeHeaderValue(mediaType) : default; + } + + private static int GetScore(ReadOnlyMediaTypeHeaderValue mediaType) + { + // Higher score == lower priority - see comments on MediaType. + if (mediaType.MatchesAllTypes) + { + return 4; + } + else if (mediaType.MatchesAllSubTypes) + { + return 3; + } + else if (mediaType.MatchesAllSubTypesWithoutSuffix) + { + return 2; + } + else + { + return 1; + } + } + + private sealed class ConsumesMetadataEndpointComparer : EndpointMetadataComparer + { + protected override int CompareMetadata(IAcceptsMetadata? x, IAcceptsMetadata? y) + { + // Ignore the metadata if it has an empty list of content types. + return base.CompareMetadata( + x?.ContentTypes.Count > 0 ? x : null, + y?.ContentTypes.Count > 0 ? y : null); + } + } + + private sealed class ConsumesPolicyJumpTable : PolicyJumpTable + { + private readonly (ReadOnlyMediaTypeHeaderValue mediaType, int destination)[] _destinations; + private readonly int _exitDestination; + private readonly int _noContentTypeDestination; + + public ConsumesPolicyJumpTable(int exitDestination, int noContentTypeDestination, (ReadOnlyMediaTypeHeaderValue mediaType, int destination)[] destinations) + { + _exitDestination = exitDestination; + _noContentTypeDestination = noContentTypeDestination; + _destinations = destinations; + } + + public override int GetDestination(HttpContext httpContext) + { + var contentType = httpContext.Request.ContentType; + + if (string.IsNullOrEmpty(contentType)) + { + return _noContentTypeDestination; + } + + var requestMediaType = new ReadOnlyMediaTypeHeaderValue(contentType); + var destinations = _destinations; + for (var i = 0; i < destinations.Length; i++) + { + + var destination = destinations[i].mediaType; + if (requestMediaType.IsSubsetOf(destination)) + { + return destinations[i].destination; + } + } + + return _exitDestination; + } + } +} diff --git a/src/Http/Routing/src/Microsoft.AspNetCore.Routing.csproj b/src/Http/Routing/src/Microsoft.AspNetCore.Routing.csproj index e92de9551933..20e688ff0daf 100644 --- a/src/Http/Routing/src/Microsoft.AspNetCore.Routing.csproj +++ b/src/Http/Routing/src/Microsoft.AspNetCore.Routing.csproj @@ -1,10 +1,12 @@ - + - ASP.NET Core middleware for routing requests to application logic and for generating links. -Commonly used types: -Microsoft.AspNetCore.Routing.Route -Microsoft.AspNetCore.Routing.RouteCollection + + ASP.NET Core middleware for routing requests to application logic and for generating links. + Commonly used types: + Microsoft.AspNetCore.Routing.Route + Microsoft.AspNetCore.Routing.RouteCollection + $(DefaultNetCoreTargetFramework) true true @@ -26,6 +28,8 @@ Microsoft.AspNetCore.Routing.RouteCollection + + diff --git a/src/Mvc/Mvc.Core/test/Routing/ConsumesMatcherPolicyTest.cs b/src/Http/Routing/test/UnitTests/Matching/AcceptsMatcherPolicyTest.cs similarity index 85% rename from src/Mvc/Mvc.Core/test/Routing/ConsumesMatcherPolicyTest.cs rename to src/Http/Routing/test/UnitTests/Matching/AcceptsMatcherPolicyTest.cs index 8575e3d9ed3d..dea6bc6bd1d1 100644 --- a/src/Mvc/Mvc.Core/test/Routing/ConsumesMatcherPolicyTest.cs +++ b/src/Http/Routing/test/UnitTests/Matching/AcceptsMatcherPolicyTest.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using System; @@ -6,16 +6,16 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Routing; -using Microsoft.AspNetCore.Routing.Matching; using Microsoft.AspNetCore.Routing.Patterns; using Xunit; -namespace Microsoft.AspNetCore.Mvc.Routing +namespace Microsoft.AspNetCore.Routing.Matching { // There are some unit tests here for the IEndpointSelectorPolicy implementation. // The INodeBuilderPolicy implementation is well-tested by functional tests. - public class ConsumesMatcherPolicyTest + public class AcceptsMatcherPolicyTest { [Fact] public void INodeBuilderPolicy_AppliesToEndpoints_EndpointWithoutMetadata_ReturnsFalse() @@ -38,7 +38,7 @@ public void INodeBuilderPolicy_AppliesToEndpoints_EndpointWithoutContentTypes_Re // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty())), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty())), }; var policy = (INodeBuilderPolicy)CreatePolicy(); @@ -56,8 +56,8 @@ public void INodeBuilderPolicy_AppliesToEndpoints_EndpointHasContentTypes_Return // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty())), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", })), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty())), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/json", })), }; var policy = (INodeBuilderPolicy)CreatePolicy(); @@ -75,8 +75,8 @@ public void INodeBuilderPolicy_AppliesToEndpoints_WithDynamicMetadata_ReturnsFal // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty()), new DynamicEndpointMetadata()), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", })), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty()), new DynamicEndpointMetadata()), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/json", })), }; var policy = (INodeBuilderPolicy)CreatePolicy(); @@ -109,7 +109,7 @@ public void IEndpointSelectorPolicy_AppliesToEndpoints_EndpointWithoutContentTyp // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty()), new DynamicEndpointMetadata()), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty()), new DynamicEndpointMetadata()), }; var policy = (IEndpointSelectorPolicy)CreatePolicy(); @@ -127,8 +127,8 @@ public void IEndpointSelectorPolicy_AppliesToEndpoints_EndpointHasContentTypes_R // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty()), new DynamicEndpointMetadata()), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", })), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty()), new DynamicEndpointMetadata()), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/json", })), }; var policy = (IEndpointSelectorPolicy)CreatePolicy(); @@ -146,8 +146,8 @@ public void IEndpointSelectorPolicy_AppliesToEndpoints_WithoutDynamicMetadata_Re // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty())), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", })), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty())), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/json", })), }; var policy = (IEndpointSelectorPolicy)CreatePolicy(); @@ -167,11 +167,11 @@ public void GetEdges_GroupsByContentType() { // These are arrange in an order that we won't actually see in a product scenario. It's done // this way so we can verify that ordering is preserved by GetEdges. - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", "application/*+json", })), - CreateEndpoint("/", new ConsumesMetadata(Array.Empty())), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/xml", "application/*+xml", })), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/*", })), - CreateEndpoint("/", new ConsumesMetadata(new[]{ "*/*", })), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/json", "application/*+json", })), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty())), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/xml", "application/*+xml", })), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/*", })), + CreateEndpoint("/", new AcceptsMetadata(new[]{ "*/*", })), }; var policy = CreatePolicy(); @@ -227,9 +227,9 @@ public void GetEdges_GroupsByContentType_CreatesHttp415Endpoint() { // These are arrange in an order that we won't actually see in a product scenario. It's done // this way so we can verify that ordering is preserved by GetEdges. - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", "application/*+json", })), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/xml", "application/*+xml", })), - CreateEndpoint("/", new ConsumesMetadata(new[] { "application/*", })), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/json", "application/*+json", })), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/xml", "application/*+xml", })), + CreateEndpoint("/", new AcceptsMetadata(new[] { "application/*", })), }; var policy = CreatePolicy(); @@ -248,7 +248,7 @@ public void GetEdges_GroupsByContentType_CreatesHttp415Endpoint() e => { Assert.Equal("*/*", e.State); - Assert.Equal(ConsumesMatcherPolicy.Http415EndpointDisplayName, Assert.Single(e.Endpoints).DisplayName); + Assert.Equal(AcceptsMatcherPolicy.Http415EndpointDisplayName, Assert.Single(e.Endpoints).DisplayName); }, e => { @@ -343,7 +343,7 @@ public async Task ApplyAsync_EndpointAllowsAnyContentType_MatchWithoutContentTyp // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty())), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty())), }; var candidates = CreateCandidateSet(endpoints); @@ -364,7 +364,7 @@ public async Task ApplyAsync_EndpointHasWildcardContentType_MatchWithoutContentT // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "*/*" })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "*/*" })), }; var candidates = CreateCandidateSet(endpoints); @@ -412,7 +412,7 @@ public async Task ApplyAsync_EndpointAllowsAnyContentType_MatchWithAnyContentTyp // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(Array.Empty())), + CreateEndpoint("/", new AcceptsMetadata(Array.Empty())), }; var candidates = CreateCandidateSet(endpoints); @@ -439,7 +439,7 @@ public async Task ApplyAsync_EndpointHasWildcardContentType_MatchWithAnyContentT // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "*/*" })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "*/*" })), }; var candidates = CreateCandidateSet(endpoints); @@ -466,7 +466,7 @@ public async Task ApplyAsync_EndpointHasSubTypeWildcard_MatchWithValidContentTyp // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "application/*+json", })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "application/*+json", })), }; var candidates = CreateCandidateSet(endpoints); @@ -493,7 +493,7 @@ public async Task ApplyAsync_EndpointHasMultipleContentType_MatchWithValidConten // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "text/xml", "application/xml", })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "text/xml", "application/xml", })), }; var candidates = CreateCandidateSet(endpoints); @@ -520,7 +520,7 @@ public async Task ApplyAsync_EndpointDoesNotMatch_Returns415() // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "text/xml", "application/xml", })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "text/xml", "application/xml", })), }; var candidates = CreateCandidateSet(endpoints); @@ -548,7 +548,7 @@ public async Task ApplyAsync_EndpointDoesNotMatch_DoesNotReturns415WithContentTy // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "text/xml", "application/xml", })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "text/xml", "application/xml", })), CreateEndpoint("/", null) }; @@ -577,8 +577,8 @@ public async Task ApplyAsync_EndpointDoesNotMatch_DoesNotReturns415WithContentTy // Arrange var endpoints = new[] { - CreateEndpoint("/", new ConsumesMetadata(new string[] { "text/xml", "application/xml", })), - CreateEndpoint("/", new ConsumesMetadata(new string[] { "*/*", })) + CreateEndpoint("/", new AcceptsMetadata(new string[] { "text/xml", "application/xml", })), + CreateEndpoint("/", new AcceptsMetadata(new string[] { "*/*", })) }; var candidates = CreateCandidateSet(endpoints); @@ -601,7 +601,7 @@ public async Task ApplyAsync_EndpointDoesNotMatch_DoesNotReturns415WithContentTy Assert.Null(httpContext.GetEndpoint()); } - private static RouteEndpoint CreateEndpoint(string template, ConsumesMetadata consumesMetadata, params object[] more) + private static RouteEndpoint CreateEndpoint(string template, AcceptsMetadata consumesMetadata, params object[] more) { var metadata = new List(); if (consumesMetadata != null) @@ -627,9 +627,9 @@ private static CandidateSet CreateCandidateSet(Endpoint[] endpoints) return new CandidateSet(endpoints, new RouteValueDictionary[endpoints.Length], new int[endpoints.Length]); } - private static ConsumesMatcherPolicy CreatePolicy() + private static AcceptsMatcherPolicy CreatePolicy() { - return new ConsumesMatcherPolicy(); + return new AcceptsMatcherPolicy(); } private class DynamicEndpointMetadata : IDynamicEndpointMetadata diff --git a/src/Http/Routing/test/UnitTests/Microsoft.AspNetCore.Routing.Tests.csproj b/src/Http/Routing/test/UnitTests/Microsoft.AspNetCore.Routing.Tests.csproj index 9d9344c05084..7b6df8ebded4 100644 --- a/src/Http/Routing/test/UnitTests/Microsoft.AspNetCore.Routing.Tests.csproj +++ b/src/Http/Routing/test/UnitTests/Microsoft.AspNetCore.Routing.Tests.csproj @@ -1,4 +1,4 @@ - + $(DefaultNetCoreTargetFramework) @@ -14,4 +14,8 @@ + + + + diff --git a/src/Mvc/Mvc.ApiExplorer/src/DefaultApiDescriptionProvider.cs b/src/Mvc/Mvc.ApiExplorer/src/DefaultApiDescriptionProvider.cs index 5a87ad2089fe..5a5d82e1ff6c 100644 --- a/src/Mvc/Mvc.ApiExplorer/src/DefaultApiDescriptionProvider.cs +++ b/src/Mvc/Mvc.ApiExplorer/src/DefaultApiDescriptionProvider.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.ActionConstraints; using Microsoft.AspNetCore.Mvc.Controllers; diff --git a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs index 707a80f8c8e8..007989fcb106 100644 --- a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs +++ b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs @@ -598,6 +598,39 @@ public void HandleMultipleProduces() }); } + [Fact] + public void HandleAcceptsMetadata() + { + // Arrange + var builder = new TestEndpointRouteBuilder(new ApplicationBuilder(null)); + builder.MapPost("/api/todos", () => "") + .Accepts("application/json", "application/xml"); + var context = new ApiDescriptionProviderContext(Array.Empty()); + + var endpointDataSource = builder.DataSources.OfType().Single(); + var hostEnvironment = new HostEnvironment + { + ApplicationName = nameof(EndpointMetadataApiDescriptionProviderTest) + }; + var provider = new EndpointMetadataApiDescriptionProvider(endpointDataSource, hostEnvironment, new ServiceProviderIsService()); + + // Act + provider.OnProvidersExecuting(context); + provider.OnProvidersExecuted(context); + + // Assert + Assert.Collection( + context.Results.SelectMany(r => r.SupportedRequestFormats), + requestType => + { + Assert.Equal("application/json", requestType.MediaType); + }, + requestType => + { + Assert.Equal("application/xml", requestType.MediaType); + }); + } + private static IEnumerable GetSortedMediaTypes(ApiResponseType apiResponseType) { return apiResponseType.ApiResponseFormats diff --git a/src/Mvc/Mvc.Core/src/ApiExplorer/IApiRequestMetadataProvider.cs b/src/Mvc/Mvc.Core/src/ApiExplorer/IApiRequestMetadataProvider.cs index 1df6dc86aca2..45b7b1ffe9c7 100644 --- a/src/Mvc/Mvc.Core/src/ApiExplorer/IApiRequestMetadataProvider.cs +++ b/src/Mvc/Mvc.Core/src/ApiExplorer/IApiRequestMetadataProvider.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.AspNetCore.Mvc.Formatters; @@ -17,4 +18,4 @@ public interface IApiRequestMetadataProvider : IFilterMetadata /// The void SetContentTypes(MediaTypeCollection contentTypes); } -} \ No newline at end of file +} diff --git a/src/Mvc/Mvc.Core/src/Builder/OpenApiEndpointConventionBuilderExtensions.cs b/src/Mvc/Mvc.Core/src/Builder/OpenApiEndpointConventionBuilderExtensions.cs index 78c9b6a826db..dbb4326f4465 100644 --- a/src/Mvc/Mvc.Core/src/Builder/OpenApiEndpointConventionBuilderExtensions.cs +++ b/src/Mvc/Mvc.Core/src/Builder/OpenApiEndpointConventionBuilderExtensions.cs @@ -1,8 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Net.Http.Headers; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Core; +using Microsoft.AspNetCore.Mvc.Formatters; using Microsoft.AspNetCore.Routing; namespace Microsoft.AspNetCore.Http @@ -41,7 +45,7 @@ public static MinimalActionEndpointConventionBuilder ExcludeFromDescription(this public static MinimalActionEndpointConventionBuilder Produces(this MinimalActionEndpointConventionBuilder builder, #pragma warning restore RS0026 int statusCode = StatusCodes.Status200OK, - string? contentType = null, + string? contentType = null, params string[] additionalContentTypes) { return Produces(builder, statusCode, typeof(TResponse), contentType, additionalContentTypes); @@ -120,5 +124,37 @@ public static MinimalActionEndpointConventionBuilder ProducesValidationProblem(t return Produces(builder, statusCode, contentType); } + + /// + /// Adds the to for all builders + /// produced by . + /// + /// The type of the request. + /// The . + /// The request content type. Defaults to "application/json" if empty. + /// Additional response content types the endpoint produces for the supplied status code. + /// A that can be used to further customize the endpoint. + public static MinimalActionEndpointConventionBuilder Accepts(this MinimalActionEndpointConventionBuilder builder, string contentType, params string[] additionalContentTypes) + { + Accepts(builder, typeof(TRequest), contentType, additionalContentTypes); + + return builder; + } + + /// + /// Adds the to for all builders + /// produced by . + /// + /// The . + /// The type of the request. Defaults to null. + /// The response content type that the endpoint accepts. + /// Additional response content types the endpoint accepts + /// A that can be used to further customize the endpoint. + public static MinimalActionEndpointConventionBuilder Accepts(this MinimalActionEndpointConventionBuilder builder, Type requestType, + string contentType, params string[] additionalContentTypes) + { + builder.WithMetadata(new ConsumesAttribute(requestType, contentType, additionalContentTypes)); + return builder; + } } } diff --git a/src/Mvc/Mvc.Core/src/ConsumesAttribute.cs b/src/Mvc/Mvc.Core/src/ConsumesAttribute.cs index ea4e3a6ae635..817e54a724d5 100644 --- a/src/Mvc/Mvc.Core/src/ConsumesAttribute.cs +++ b/src/Mvc/Mvc.Core/src/ConsumesAttribute.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.ActionConstraints; using Microsoft.AspNetCore.Mvc.ApiExplorer; @@ -23,7 +24,8 @@ public class ConsumesAttribute : Attribute, IResourceFilter, IConsumesActionConstraint, - IApiRequestMetadataProvider + IApiRequestMetadataProvider, + IAcceptsMetadata { /// /// The order for consumes attribute. @@ -33,6 +35,8 @@ public class ConsumesAttribute : /// /// Creates a new instance of . + /// The request content type + /// The additional list of allowed request content types /// public ConsumesAttribute(string contentType, params string[] otherContentTypes) { @@ -53,6 +57,34 @@ public ConsumesAttribute(string contentType, params string[] otherContentTypes) ContentTypes = GetContentTypes(contentType, otherContentTypes); } + /// + /// Creates a new instance of . + /// The type being read from the request + /// The request content type + /// The additional list of allowed request content types + /// + public ConsumesAttribute(Type requestType, string contentType, params string[] otherContentTypes) + { + if (contentType == null) + { + throw new ArgumentNullException(nameof(contentType)); + } + + // We want to ensure that the given provided content types are valid values, so + // we validate them using the semantics of MediaTypeHeaderValue. + MediaTypeHeaderValue.Parse(contentType); + + for (var i = 0; i < otherContentTypes.Length; i++) + { + MediaTypeHeaderValue.Parse(otherContentTypes[i]); + } + + ContentTypes = GetContentTypes(contentType, otherContentTypes); + _contentTypes = GetAllContentTypes(contentType, otherContentTypes); + _requestType = requestType; + + } + // The value used is a non default value so that it avoids getting mixed with other action constraints // with default order. /// @@ -64,6 +96,14 @@ public ConsumesAttribute(string contentType, params string[] otherContentTypes) /// public MediaTypeCollection ContentTypes { get; set; } + readonly Type? _requestType; + + readonly List _contentTypes = new(); + + Type? IAcceptsMetadata.RequestType => _requestType; + + IReadOnlyList IAcceptsMetadata.ContentTypes => _contentTypes; + /// public void OnResourceExecuting(ResourceExecutingContext context) { @@ -223,6 +263,16 @@ private MediaTypeCollection GetContentTypes(string firstArg, string[] args) return contentTypes; } + private static List GetAllContentTypes(string contentType, string[] additionalContentTypes) + { + var allContentTypes = new List() + { + contentType + }; + allContentTypes.AddRange(additionalContentTypes); + return allContentTypes; + } + /// public void SetContentTypes(MediaTypeCollection contentTypes) { diff --git a/src/Mvc/Mvc.Core/src/DependencyInjection/MvcCoreServiceCollectionExtensions.cs b/src/Mvc/Mvc.Core/src/DependencyInjection/MvcCoreServiceCollectionExtensions.cs index 27157062d7c8..0c9d81fbf9ac 100644 --- a/src/Mvc/Mvc.Core/src/DependencyInjection/MvcCoreServiceCollectionExtensions.cs +++ b/src/Mvc/Mvc.Core/src/DependencyInjection/MvcCoreServiceCollectionExtensions.cs @@ -19,6 +19,7 @@ using Microsoft.AspNetCore.Mvc.ModelBinding.Validation; using Microsoft.AspNetCore.Mvc.Routing; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Matching; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Options; @@ -176,7 +177,6 @@ internal static void AddMvcCoreServices(IServiceCollection services) services.TryAddEnumerable(ServiceDescriptor.Transient()); // Policies for Endpoints - services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddEnumerable(ServiceDescriptor.Singleton()); // diff --git a/src/Mvc/Mvc.Core/src/Formatters/AcceptHeaderParser.cs b/src/Mvc/Mvc.Core/src/Formatters/AcceptHeaderParser.cs index 2928a83f0cd2..befb3721b320 100644 --- a/src/Mvc/Mvc.Core/src/Formatters/AcceptHeaderParser.cs +++ b/src/Mvc/Mvc.Core/src/Formatters/AcceptHeaderParser.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using Microsoft.AspNetCore.Http.Headers; namespace Microsoft.AspNetCore.Mvc.Formatters { diff --git a/src/Mvc/Mvc.Core/src/Formatters/HttpParseResult.cs b/src/Mvc/Mvc.Core/src/Formatters/HttpParseResult.cs deleted file mode 100644 index 9969b949d847..000000000000 --- a/src/Mvc/Mvc.Core/src/Formatters/HttpParseResult.cs +++ /dev/null @@ -1,12 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.Mvc.Formatters -{ - internal enum HttpParseResult - { - Parsed, - NotParsed, - InvalidFormat, - } -} diff --git a/src/Mvc/Mvc.Core/src/Formatters/HttpTokenParsingRules.cs b/src/Mvc/Mvc.Core/src/Formatters/HttpTokenParsingRules.cs deleted file mode 100644 index df6d7d8e873e..000000000000 --- a/src/Mvc/Mvc.Core/src/Formatters/HttpTokenParsingRules.cs +++ /dev/null @@ -1,270 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Diagnostics; -using System.Text; - -namespace Microsoft.AspNetCore.Mvc.Formatters -{ - internal static class HttpTokenParsingRules - { - private static readonly bool[] TokenChars = CreateTokenChars(); - private const int MaxNestedCount = 5; - - internal const char CR = '\r'; - internal const char LF = '\n'; - internal const char SP = ' '; - internal const char Tab = '\t'; - internal const int MaxInt64Digits = 19; - internal const int MaxInt32Digits = 10; - - // iso-8859-1, Western European (ISO) - internal static readonly Encoding DefaultHttpEncoding = Encoding.GetEncoding(28591); - - private static bool[] CreateTokenChars() - { - // token = 1* - // CTL = - - var tokenChars = new bool[128]; // everything is false - - for (var i = 33; i < 127; i++) // skip Space (32) & DEL (127) - { - tokenChars[i] = true; - } - - // remove separators: these are not valid token characters - tokenChars[(byte)'('] = false; - tokenChars[(byte)')'] = false; - tokenChars[(byte)'<'] = false; - tokenChars[(byte)'>'] = false; - tokenChars[(byte)'@'] = false; - tokenChars[(byte)','] = false; - tokenChars[(byte)';'] = false; - tokenChars[(byte)':'] = false; - tokenChars[(byte)'\\'] = false; - tokenChars[(byte)'"'] = false; - tokenChars[(byte)'/'] = false; - tokenChars[(byte)'['] = false; - tokenChars[(byte)']'] = false; - tokenChars[(byte)'?'] = false; - tokenChars[(byte)'='] = false; - tokenChars[(byte)'{'] = false; - tokenChars[(byte)'}'] = false; - - return tokenChars; - } - - internal static bool IsTokenChar(char character) - { - // Must be between 'space' (32) and 'DEL' (127) - if (character > 127) - { - return false; - } - - return TokenChars[character]; - } - - internal static int GetTokenLength(string input, int startIndex) - { - Debug.Assert(input != null); - - if (startIndex >= input.Length) - { - return 0; - } - - var current = startIndex; - - while (current < input.Length) - { - if (!IsTokenChar(input[current])) - { - return current - startIndex; - } - current++; - } - return input.Length - startIndex; - } - - internal static int GetWhitespaceLength(string input, int startIndex) - { - Debug.Assert(input != null); - - if (startIndex >= input.Length) - { - return 0; - } - - var current = startIndex; - - while (current < input.Length) - { - var c = input[current]; - - if ((c == SP) || (c == Tab)) - { - current++; - continue; - } - - if (c == CR) - { - // If we have a #13 char, it must be followed by #10 and then at least one SP or HT. - if ((current + 2 < input.Length) && (input[current + 1] == LF)) - { - char spaceOrTab = input[current + 2]; - if ((spaceOrTab == SP) || (spaceOrTab == Tab)) - { - current += 3; - continue; - } - } - } - - return current - startIndex; - } - - // All characters between startIndex and the end of the string are LWS characters. - return input.Length - startIndex; - } - - internal static HttpParseResult GetQuotedStringLength(string input, int startIndex, out int length) - { - var nestedCount = 0; - return GetExpressionLength(input, startIndex, '"', '"', false, ref nestedCount, out length); - } - - // quoted-pair = "\" CHAR - // CHAR = - internal static HttpParseResult GetQuotedPairLength(string input, int startIndex, out int length) - { - Debug.Assert(input != null); - Debug.Assert((startIndex >= 0) && (startIndex < input.Length)); - - length = 0; - - if (input[startIndex] != '\\') - { - return HttpParseResult.NotParsed; - } - - // Quoted-char has 2 characters. Check whether there are 2 chars left ('\' + char) - // If so, check whether the character is in the range 0-127. If not, it's an invalid value. - if ((startIndex + 2 > input.Length) || (input[startIndex + 1] > 127)) - { - return HttpParseResult.InvalidFormat; - } - - // We don't care what the char next to '\' is. - length = 2; - return HttpParseResult.Parsed; - } - - // TEXT = - // LWS = [CRLF] 1*( SP | HT ) - // CTL = - // - // Since we don't really care about the content of a quoted string or comment, we're more tolerant and - // allow these characters. We only want to find the delimiters ('"' for quoted string and '(', ')' for comment). - // - // 'nestedCount': Comments can be nested. We allow a depth of up to 5 nested comments, i.e. something like - // "(((((comment)))))". If we wouldn't define a limit an attacker could send a comment with hundreds of nested - // comments, resulting in a stack overflow exception. In addition having more than 1 nested comment (if any) - // is unusual. - private static HttpParseResult GetExpressionLength( - string input, - int startIndex, - char openChar, - char closeChar, - bool supportsNesting, - ref int nestedCount, - out int length) - { - Debug.Assert(input != null); - Debug.Assert((startIndex >= 0) && (startIndex < input.Length)); - - length = 0; - - if (input[startIndex] != openChar) - { - return HttpParseResult.NotParsed; - } - - var current = startIndex + 1; // Start parsing with the character next to the first open-char - while (current < input.Length) - { - // Only check whether we have a quoted char, if we have at least 3 characters left to read (i.e. - // quoted char + closing char). Otherwise the closing char may be considered part of the quoted char. - if ((current + 2 < input.Length) && - (GetQuotedPairLength(input, current, out var quotedPairLength) == HttpParseResult.Parsed)) - { - // We ignore invalid quoted-pairs. Invalid quoted-pairs may mean that it looked like a quoted pair, - // but we actually have a quoted-string: e.g. "\ü" ('\' followed by a char >127 - quoted-pair only - // allows ASCII chars after '\'; qdtext allows both '\' and >127 chars). - current = current + quotedPairLength; - continue; - } - - // If we support nested expressions and we find an open-char, then parse the nested expressions. - if (supportsNesting && (input[current] == openChar)) - { - nestedCount++; - try - { - // Check if we exceeded the number of nested calls. - if (nestedCount > MaxNestedCount) - { - return HttpParseResult.InvalidFormat; - } - - var nestedResult = GetExpressionLength( - input, - current, - openChar, - closeChar, - supportsNesting, - ref nestedCount, - out var nestedLength); - - switch (nestedResult) - { - case HttpParseResult.Parsed: - current += nestedLength; // add the length of the nested expression and continue. - break; - - case HttpParseResult.NotParsed: - Debug.Fail("'NotParsed' is unexpected: We started nested expression " + - "parsing, because we found the open-char. So either it's a valid nested " + - "expression or it has invalid format."); - break; - - case HttpParseResult.InvalidFormat: - // If the nested expression is invalid, we can't continue, so we fail with invalid format. - return HttpParseResult.InvalidFormat; - - default: - Debug.Fail("Unknown enum result: " + nestedResult); - break; - } - } - finally - { - nestedCount--; - } - } - - if (input[current] == closeChar) - { - length = current - startIndex + 1; - return HttpParseResult.Parsed; - } - current++; - } - - // We didn't see the final quote, therefore we have an invalid expression string. - return HttpParseResult.InvalidFormat; - } - } -} diff --git a/src/Mvc/Mvc.Core/src/Formatters/MediaType.cs b/src/Mvc/Mvc.Core/src/Formatters/MediaType.cs index b7004ca2dd59..c682d5b4f2da 100644 --- a/src/Mvc/Mvc.Core/src/Formatters/MediaType.cs +++ b/src/Mvc/Mvc.Core/src/Formatters/MediaType.cs @@ -5,6 +5,7 @@ using System.Globalization; using System.Text; using Microsoft.AspNetCore.Mvc.Core; +using Microsoft.AspNetCore.Http.Headers; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Mvc.Formatters @@ -16,7 +17,7 @@ public readonly struct MediaType { private static readonly StringSegment QualityParameter = new StringSegment("q"); - private readonly MediaTypeParameterParser _parameterParser; + private readonly ReadOnlyMediaTypeHeaderValue _mediaTypeHeaderValue; /// /// Initializes a instance. @@ -67,124 +68,7 @@ public MediaType(string mediaType, int offset, int? length) } } - _parameterParser = default(MediaTypeParameterParser); - - var typeLength = GetTypeLength(mediaType, offset, out var type); - if (typeLength == 0) - { - Type = new StringSegment(); - SubType = new StringSegment(); - SubTypeWithoutSuffix = new StringSegment(); - SubTypeSuffix = new StringSegment(); - return; - } - else - { - Type = type; - } - - var subTypeLength = GetSubtypeLength(mediaType, offset + typeLength, out var subType); - if (subTypeLength == 0) - { - SubType = new StringSegment(); - SubTypeWithoutSuffix = new StringSegment(); - SubTypeSuffix = new StringSegment(); - return; - } - else - { - SubType = subType; - - if (TryGetSuffixLength(subType, out var subtypeSuffixLength)) - { - SubTypeWithoutSuffix = subType.Subsegment(0, subType.Length - subtypeSuffixLength - 1); - SubTypeSuffix = subType.Subsegment(subType.Length - subtypeSuffixLength, subtypeSuffixLength); - } - else - { - SubTypeWithoutSuffix = SubType; - SubTypeSuffix = new StringSegment(); - } - } - - _parameterParser = new MediaTypeParameterParser(mediaType, offset + typeLength + subTypeLength, length); - } - - // All GetXXXLength methods work in the same way. They expect to be on the right position for - // the token they are parsing, for example, the beginning of the media type or the delimiter - // from a previous token, like '/', ';' or '='. - // Each method consumes the delimiter token if any, the leading whitespace, then the given token - // itself, and finally the trailing whitespace. - private static int GetTypeLength(string input, int offset, out StringSegment type) - { - if (offset < 0 || offset >= input.Length) - { - type = default(StringSegment); - return 0; - } - - var current = offset + HttpTokenParsingRules.GetWhitespaceLength(input, offset); - - // Parse the type, i.e. in media type string "/; param1=value1; param2=value2" - var typeLength = HttpTokenParsingRules.GetTokenLength(input, current); - if (typeLength == 0) - { - type = default(StringSegment); - return 0; - } - - type = new StringSegment(input, current, typeLength); - - current += typeLength; - current += HttpTokenParsingRules.GetWhitespaceLength(input, current); - - return current - offset; - } - - private static int GetSubtypeLength(string input, int offset, out StringSegment subType) - { - var current = offset; - - // Parse the separator between type and subtype - if (current < 0 || current >= input.Length || input[current] != '/') - { - subType = default(StringSegment); - return 0; - } - - current++; // skip delimiter. - current += HttpTokenParsingRules.GetWhitespaceLength(input, current); - - var subtypeLength = HttpTokenParsingRules.GetTokenLength(input, current); - if (subtypeLength == 0) - { - subType = default(StringSegment); - return 0; - } - - subType = new StringSegment(input, current, subtypeLength); - - current += subtypeLength; - current += HttpTokenParsingRules.GetWhitespaceLength(input, current); - - return current - offset; - } - - private static bool TryGetSuffixLength(StringSegment subType, out int suffixLength) - { - // Find the last instance of '+', if there is one - var startPos = subType.Offset + subType.Length - 1; - for (var currentPos = startPos; currentPos >= subType.Offset; currentPos--) - { - if (subType.Buffer[currentPos] == '+') - { - suffixLength = startPos - currentPos; - return true; - } - } - - suffixLength = 0; - return false; + _mediaTypeHeaderValue = new ReadOnlyMediaTypeHeaderValue(mediaType, offset, length); } /// @@ -193,12 +77,12 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// /// For the media type "application/json", this property gives the value "application". /// - public StringSegment Type { get; } + public StringSegment Type => _mediaTypeHeaderValue.Type; /// /// Gets whether this matches all types. /// - public bool MatchesAllTypes => Type.Equals("*", StringComparison.OrdinalIgnoreCase); + public bool MatchesAllTypes => _mediaTypeHeaderValue.MatchesAllTypes; /// /// Gets the subtype of the . @@ -207,7 +91,7 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// For the media type "application/vnd.example+json", this property gives the value /// "vnd.example+json". /// - public StringSegment SubType { get; } + public StringSegment SubType => _mediaTypeHeaderValue.SubType; /// /// Gets the subtype of the , excluding any structured syntax suffix. @@ -216,7 +100,7 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// For the media type "application/vnd.example+json", this property gives the value /// "vnd.example". /// - public StringSegment SubTypeWithoutSuffix { get; } + public StringSegment SubTypeWithoutSuffix => _mediaTypeHeaderValue.SubTypeWithoutSuffix; /// /// Gets the structured syntax suffix of the if it has one. @@ -225,7 +109,7 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// For the media type "application/vnd.example+json", this property gives the value /// "json". /// - public StringSegment SubTypeSuffix { get; } + public StringSegment SubTypeSuffix => _mediaTypeHeaderValue.SubTypeSuffix; /// /// Gets whether this matches all subtypes. @@ -236,7 +120,7 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// /// For the media type "application/json", this property is false. /// - public bool MatchesAllSubTypes => SubType.Equals("*", StringComparison.OrdinalIgnoreCase); + public bool MatchesAllSubTypes => _mediaTypeHeaderValue.MatchesAllSubTypes; /// /// Gets whether this matches all subtypes, ignoring any structured syntax suffix. @@ -247,17 +131,17 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// /// For the media type "application/vnd.example+json", this property is false. /// - public bool MatchesAllSubTypesWithoutSuffix => SubTypeWithoutSuffix.Equals("*", StringComparison.OrdinalIgnoreCase); + public bool MatchesAllSubTypesWithoutSuffix => _mediaTypeHeaderValue.MatchesAllSubTypesWithoutSuffix; /// /// Gets the of the if it has one. /// - public Encoding? Encoding => GetEncodingFromCharset(GetParameter("charset")); + public Encoding? Encoding => _mediaTypeHeaderValue.Encoding; /// /// Gets the charset parameter of the if it has one. /// - public StringSegment Charset => GetParameter("charset"); + public StringSegment Charset => _mediaTypeHeaderValue.Charset; /// /// Determines whether the current contains a wildcard. @@ -265,15 +149,7 @@ private static bool TryGetSuffixLength(StringSegment subType, out int suffixLeng /// /// true if this contains a wildcard; otherwise false. /// - public bool HasWildcard - { - get - { - return MatchesAllTypes || - MatchesAllSubTypesWithoutSuffix || - GetParameter("*").Equals("*", StringComparison.OrdinalIgnoreCase); - } - } + public bool HasWildcard => _mediaTypeHeaderValue.HasWildcard; /// /// Determines whether the current is a subset of the @@ -284,11 +160,7 @@ public bool HasWildcard /// true if this is a subset of ; otherwise false. /// public bool IsSubsetOf(MediaType set) - { - return MatchesType(set) && - MatchesSubtype(set) && - ContainsAllParameters(set._parameterParser); - } + => _mediaTypeHeaderValue.IsSubsetOf(set._mediaTypeHeaderValue); /// /// Gets the parameter of the media type. @@ -299,9 +171,7 @@ public bool IsSubsetOf(MediaType set) /// null. /// public StringSegment GetParameter(string parameterName) - { - return GetParameter(new StringSegment(parameterName)); - } + => _mediaTypeHeaderValue.GetParameter(parameterName); /// /// Gets the parameter of the media type. @@ -312,19 +182,7 @@ public StringSegment GetParameter(string parameterName) /// null. /// public StringSegment GetParameter(StringSegment parameterName) - { - var parametersParser = _parameterParser; - - while (parametersParser.ParseNextParameter(out var parameter)) - { - if (parameter.HasName(parameterName)) - { - return parameter.Value; - } - } - - return new StringSegment(); - } + => _mediaTypeHeaderValue.GetParameter(parameterName); /// /// Replaces the encoding of the given with the provided @@ -404,7 +262,7 @@ public static string ReplaceEncoding(StringSegment mediaType, Encoding encoding) /// The parsed media type with its associated quality. public static MediaTypeSegmentWithQuality CreateMediaTypeSegmentWithQuality(string mediaType, int start) { - var parsedMediaType = new MediaType(mediaType, start, length: null); + var parsedMediaType = new ReadOnlyMediaTypeHeaderValue(mediaType, start, length: null); // Short-circuit use of the MediaTypeParameterParser if constructor detected an invalid type or subtype. // Parser would set ParsingFailed==true in this case. But, we handle invalid parameters as a separate case. @@ -414,16 +272,16 @@ public static MediaTypeSegmentWithQuality CreateMediaTypeSegmentWithQuality(stri return default(MediaTypeSegmentWithQuality); } - var parser = parsedMediaType._parameterParser; - var quality = 1.0d; + + var parser = parsedMediaType.ParameterParser; while (parser.ParseNextParameter(out var parameter)) { if (parameter.HasName(QualityParameter)) { // If media type contains two `q` values i.e. it's invalid in an uncommon way, pick last value. quality = double.Parse( - parameter.Value.Value, NumberStyles.AllowDecimalPoint, + parameter.Value.AsSpan(), NumberStyles.AllowDecimalPoint, NumberFormatInfo.InvariantInfo); } } @@ -542,7 +400,7 @@ private bool ContainsAllParameters(MediaTypeParameterParser setParameters) // Copy the parser as we need to iterate multiple times over it. // We can do this because it's a struct - var subSetParameters = _parameterParser; + var subSetParameters = _mediaTypeHeaderValue.ParameterParser; parameterFound = false; while (subSetParameters.ParseNextParameter(out var subSetParameter) && !parameterFound) { diff --git a/src/Mvc/Mvc.Core/src/Microsoft.AspNetCore.Mvc.Core.csproj b/src/Mvc/Mvc.Core/src/Microsoft.AspNetCore.Mvc.Core.csproj index b2134e5d98d9..a0b45ab9d83a 100644 --- a/src/Mvc/Mvc.Core/src/Microsoft.AspNetCore.Mvc.Core.csproj +++ b/src/Mvc/Mvc.Core/src/Microsoft.AspNetCore.Mvc.Core.csproj @@ -31,6 +31,9 @@ Microsoft.AspNetCore.Mvc.RouteAttribute + + + diff --git a/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt b/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt index 02d3d3791fdd..0c8597b3f152 100644 --- a/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt +++ b/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt @@ -528,6 +528,7 @@ *REMOVED*~virtual Microsoft.AspNetCore.Mvc.Routing.UrlHelperBase.Content(string contentPath) -> string *REMOVED*~virtual Microsoft.AspNetCore.Mvc.Routing.UrlHelperBase.IsLocalUrl(string url) -> bool *REMOVED*~virtual Microsoft.AspNetCore.Mvc.Routing.UrlHelperBase.Link(string routeName, object values) -> string +Microsoft.AspNetCore.Mvc.ConsumesAttribute.ConsumesAttribute(System.Type! requestType, string! contentType, params string![]! otherContentTypes) -> void Microsoft.AspNetCore.Mvc.JsonOptions.AllowInputFormatterExceptionMessages.get -> bool Microsoft.AspNetCore.Mvc.JsonOptions.AllowInputFormatterExceptionMessages.set -> void Microsoft.AspNetCore.Mvc.Controllers.ControllerActivatorProvider.CreateAsyncReleaser(Microsoft.AspNetCore.Mvc.Controllers.ControllerActionDescriptor! descriptor) -> System.Func? @@ -544,6 +545,8 @@ Microsoft.AspNetCore.Mvc.Infrastructure.ActionDescriptorCollection.Items.get -> Microsoft.AspNetCore.Mvc.Infrastructure.AmbiguousActionException.AmbiguousActionException(System.Runtime.Serialization.SerializationInfo! info, System.Runtime.Serialization.StreamingContext context) -> void Microsoft.AspNetCore.Mvc.Infrastructure.AmbiguousActionException.AmbiguousActionException(string? message) -> void Microsoft.AspNetCore.Mvc.Infrastructure.ContentResultExecutor.ContentResultExecutor(Microsoft.Extensions.Logging.ILogger! logger, Microsoft.AspNetCore.Mvc.Infrastructure.IHttpResponseStreamWriterFactory! httpResponseStreamWriterFactory) -> void +static Microsoft.AspNetCore.Http.OpenApiEndpointConventionBuilderExtensions.Accepts(this Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! builder, System.Type! requestType, string! contentType, params string![]! additionalContentTypes) -> Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! +static Microsoft.AspNetCore.Http.OpenApiEndpointConventionBuilderExtensions.Accepts(this Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! builder, string! contentType, params string![]! additionalContentTypes) -> Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! ~Microsoft.AspNetCore.Mvc.Infrastructure.DefaultOutputFormatterSelector.DefaultOutputFormatterSelector(Microsoft.Extensions.Options.IOptions! options, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory) -> void Microsoft.AspNetCore.Mvc.Infrastructure.FileContentResultExecutor.FileContentResultExecutor(Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory) -> void Microsoft.AspNetCore.Mvc.Infrastructure.FileResultExecutorBase.FileResultExecutorBase(Microsoft.Extensions.Logging.ILogger! logger) -> void diff --git a/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs b/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs index c0b3d89e6b67..0f825940ca24 100644 --- a/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs +++ b/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.ActionConstraints; using Microsoft.AspNetCore.Mvc.Filters; @@ -378,9 +379,9 @@ private static void AddActionDataToBuilder( builder.Metadata.Add(new HttpMethodMetadata(httpMethodActionConstraint.HttpMethods)); } else if (actionConstraint is ConsumesAttribute consumesAttribute && - !builder.Metadata.OfType().Any()) + !builder.Metadata.OfType().Any()) { - builder.Metadata.Add(new ConsumesMetadata(consumesAttribute.ContentTypes.ToArray())); + builder.Metadata.Add(new AcceptsMetadata(consumesAttribute.ContentTypes.ToArray())); } else if (!builder.Metadata.Contains(actionConstraint)) { diff --git a/src/Mvc/Mvc.Core/src/Routing/ConsumesMatcherPolicy.cs b/src/Mvc/Mvc.Core/src/Routing/ConsumesMatcherPolicy.cs deleted file mode 100644 index 1584042c593e..000000000000 --- a/src/Mvc/Mvc.Core/src/Routing/ConsumesMatcherPolicy.cs +++ /dev/null @@ -1,394 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Mvc.Formatters; -using Microsoft.AspNetCore.Routing; -using Microsoft.AspNetCore.Routing.Matching; - -namespace Microsoft.AspNetCore.Mvc.Routing -{ - internal class ConsumesMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy, IEndpointSelectorPolicy - { - internal const string Http415EndpointDisplayName = "415 HTTP Unsupported Media Type"; - internal const string AnyContentType = "*/*"; - - // Run after HTTP methods, but before 'default'. - public override int Order { get; } = -100; - - public IComparer Comparer { get; } = new ConsumesMetadataEndpointComparer(); - - bool INodeBuilderPolicy.AppliesToEndpoints(IReadOnlyList endpoints) - { - if (endpoints == null) - { - throw new ArgumentNullException(nameof(endpoints)); - } - - if (ContainsDynamicEndpoints(endpoints)) - { - return false; - } - - return AppliesToEndpointsCore(endpoints); - } - - bool IEndpointSelectorPolicy.AppliesToEndpoints(IReadOnlyList endpoints) - { - if (endpoints == null) - { - throw new ArgumentNullException(nameof(endpoints)); - } - - // When the node contains dynamic endpoints we can't make any assumptions. - return ContainsDynamicEndpoints(endpoints); - } - - private bool AppliesToEndpointsCore(IReadOnlyList endpoints) - { - return endpoints.Any(e => e.Metadata.GetMetadata()?.ContentTypes.Count > 0); - } - - public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates) - { - if (httpContext == null) - { - throw new ArgumentNullException(nameof(httpContext)); - } - - if (candidates == null) - { - throw new ArgumentNullException(nameof(candidates)); - } - - // We want to return a 415 if we eliminated ALL of the currently valid endpoints due to content type - // mismatch. - bool? needs415Endpoint = null; - - for (var i = 0; i < candidates.Count; i++) - { - // We do this check first for consistency with how 415 is implemented for the graph version - // of this code. We still want to know if any endpoints in this set require an a ContentType - // even if those endpoints are already invalid - hence the null check. - var metadata = candidates[i].Endpoint?.Metadata.GetMetadata(); - if (metadata == null || metadata.ContentTypes.Count == 0) - { - // Can match any content type. - needs415Endpoint = false; - continue; - } - - // Saw a valid endpoint. - needs415Endpoint = needs415Endpoint ?? true; - - if (!candidates.IsValidCandidate(i)) - { - // If the candidate is already invalid, then do a search to see if it has a wildcard content type. - // - // We don't want to return a 415 if any content type could be accepted depending on other parameters. - if (metadata != null) - { - for (var j = 0; j < metadata.ContentTypes.Count; j++) - { - if (string.Equals("*/*", metadata.ContentTypes[j], StringComparison.Ordinal)) - { - needs415Endpoint = false; - break; - } - } - } - - continue; - } - - var contentType = httpContext.Request.ContentType; - var mediaType = string.IsNullOrEmpty(contentType) ? (MediaType?)null : new MediaType(contentType); - - var matched = false; - for (var j = 0; j < metadata.ContentTypes.Count; j++) - { - var candidateMediaType = new MediaType(metadata.ContentTypes[j]); - if (candidateMediaType.MatchesAllTypes) - { - // We don't need a 415 response because there's an endpoint that would accept any type. - needs415Endpoint = false; - } - - // If there's no ContentType, then then can only matched by a wildcard `*/*`. - if (mediaType == null && !candidateMediaType.MatchesAllTypes) - { - continue; - } - - // We have a ContentType but it's not a match. - else if (mediaType != null && !mediaType.Value.IsSubsetOf(candidateMediaType)) - { - continue; - } - - // We have a ContentType and we accept any value OR we have a ContentType and it's a match. - matched = true; - needs415Endpoint = false; - break; - } - - if (!matched) - { - candidates.SetValidity(i, false); - } - } - - if (needs415Endpoint == true) - { - // We saw some endpoints coming in, and we eliminated them all. - httpContext.SetEndpoint(CreateRejectionEndpoint()); - } - - return Task.CompletedTask; - } - - public IReadOnlyList GetEdges(IReadOnlyList endpoints) - { - if (endpoints == null) - { - throw new ArgumentNullException(nameof(endpoints)); - } - - // The algorithm here is designed to be preserve the order of the endpoints - // while also being relatively simple. Preserving order is important. - - // First, build a dictionary of all of the content-type patterns that are included - // at this node. - // - // For now we're just building up the set of keys. We don't add any endpoints - // to lists now because we don't want ordering problems. - var edges = new Dictionary>(StringComparer.OrdinalIgnoreCase); - for (var i = 0; i < endpoints.Count; i++) - { - var endpoint = endpoints[i]; - var contentTypes = endpoint.Metadata.GetMetadata()?.ContentTypes; - if (contentTypes == null || contentTypes.Count == 0) - { - contentTypes = new string[] { AnyContentType, }; - } - - for (var j = 0; j < contentTypes.Count; j++) - { - var contentType = contentTypes[j]; - - if (!edges.ContainsKey(contentType)) - { - edges.Add(contentType, new List()); - } - } - } - - // Now in a second loop, add endpoints to these lists. We've enumerated all of - // the states, so we want to see which states this endpoint matches. - for (var i = 0; i < endpoints.Count; i++) - { - var endpoint = endpoints[i]; - var contentTypes = endpoint.Metadata.GetMetadata()?.ContentTypes ?? Array.Empty(); - if (contentTypes.Count == 0) - { - // OK this means that this endpoint matches *all* content methods. - // So, loop and add it to all states. - foreach (var kvp in edges) - { - kvp.Value.Add(endpoint); - } - } - else - { - // OK this endpoint matches specific content types -- we have to loop through edges here - // because content types could either be exact (like 'application/json') or they - // could have wildcards (like 'text/*'). We don't expect wildcards to be especially common - // with consumes, but we need to support it. - foreach (var kvp in edges) - { - // The edgeKey maps to a possible request header value - var edgeKey = new MediaType(kvp.Key); - - for (var j = 0; j < contentTypes.Count; j++) - { - var contentType = contentTypes[j]; - - var mediaType = new MediaType(contentType); - - // Example: 'application/json' is subset of 'application/*' - // - // This means that when the request has content-type 'application/json' an endpoint - // what consumes 'application/*' should match. - if (edgeKey.IsSubsetOf(mediaType)) - { - kvp.Value.Add(endpoint); - - // It's possible that a ConsumesMetadata defines overlapping wildcards. Don't add an endpoint - // to any edge twice - break; - } - } - } - } - } - - // If after we're done there isn't any endpoint that accepts */*, then we'll synthesize an - // endpoint that always returns a 415. - if (!edges.TryGetValue(AnyContentType, out var anyEndpoints)) - { - edges.Add(AnyContentType, new List() - { - CreateRejectionEndpoint(), - }); - - // Add a node to use when there is no request content type. - // When there is no content type we want the policy to no-op - edges.Add(string.Empty, endpoints.ToList()); - } - else - { - // If there is an endpoint that accepts */* then it is also used when there is no content type - edges.Add(string.Empty, anyEndpoints.ToList()); - } - - - return edges - .Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value)) - .ToArray(); - } - - private Endpoint CreateRejectionEndpoint() - { - return new Endpoint( - (context) => - { - context.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType; - return Task.CompletedTask; - }, - EndpointMetadataCollection.Empty, - Http415EndpointDisplayName); - } - - public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList edges) - { - if (edges == null) - { - throw new ArgumentNullException(nameof(edges)); - } - - // Since our 'edges' can have wildcards, we do a sort based on how wildcard-ey they - // are then then execute them in linear order. - var ordered = edges - .Select(e => (mediaType: CreateEdgeMediaType(ref e), destination: e.Destination)) - .OrderBy(e => GetScore(e.mediaType)) - .ToArray(); - - // If any edge matches all content types, then treat that as the 'exit'. This will - // always happen because we insert a 415 endpoint. - for (var i = 0; i < ordered.Length; i++) - { - if (ordered[i].mediaType.MatchesAllTypes) - { - exitDestination = ordered[i].destination; - break; - } - } - - var noContentTypeDestination = GetNoContentTypeDestination(ordered); - - return new ConsumesPolicyJumpTable(exitDestination, noContentTypeDestination, ordered); - } - - private static int GetNoContentTypeDestination((MediaType mediaType, int destination)[] destinations) - { - for (var i = 0; i < destinations.Length; i++) - { - if (!destinations[i].mediaType.Type.HasValue) - { - return destinations[i].destination; - } - } - - throw new InvalidOperationException("Could not find destination for no content type."); - } - - private static MediaType CreateEdgeMediaType(ref PolicyJumpTableEdge e) - { - var mediaType = (string)e.State; - return !string.IsNullOrEmpty(mediaType) ? new MediaType(mediaType) : default; - } - - private int GetScore(in MediaType mediaType) - { - // Higher score == lower priority - see comments on MediaType. - if (mediaType.MatchesAllTypes) - { - return 4; - } - else if (mediaType.MatchesAllSubTypes) - { - return 3; - } - else if (mediaType.MatchesAllSubTypesWithoutSuffix) - { - return 2; - } - else - { - return 1; - } - } - - private class ConsumesMetadataEndpointComparer : EndpointMetadataComparer - { - protected override int CompareMetadata(IConsumesMetadata? x, IConsumesMetadata? y) - { - // Ignore the metadata if it has an empty list of content types. - return base.CompareMetadata( - x?.ContentTypes.Count > 0 ? x : null, - y?.ContentTypes.Count > 0 ? y : null); - } - } - - private class ConsumesPolicyJumpTable : PolicyJumpTable - { - private readonly (MediaType mediaType, int destination)[] _destinations; - private readonly int _exitDestination; - private readonly int _noContentTypeDestination; - - public ConsumesPolicyJumpTable(int exitDestination, int noContentTypeDestination, (MediaType mediaType, int destination)[] destinations) - { - _exitDestination = exitDestination; - _noContentTypeDestination = noContentTypeDestination; - _destinations = destinations; - } - - public override int GetDestination(HttpContext httpContext) - { - var contentType = httpContext.Request.ContentType; - - if (string.IsNullOrEmpty(contentType)) - { - return _noContentTypeDestination; - } - - var requestMediaType = new MediaType(contentType); - var destinations = _destinations; - for (var i = 0; i < destinations.Length; i++) - { - if (requestMediaType.IsSubsetOf(destinations[i].mediaType)) - { - return destinations[i].destination; - } - } - - return _exitDestination; - } - } - } -} diff --git a/src/Mvc/Mvc.Core/src/Routing/ConsumesMetadata.cs b/src/Mvc/Mvc.Core/src/Routing/ConsumesMetadata.cs deleted file mode 100644 index 101dd3a26e76..000000000000 --- a/src/Mvc/Mvc.Core/src/Routing/ConsumesMetadata.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - - -using System; -using System.Collections.Generic; - -namespace Microsoft.AspNetCore.Mvc.Routing -{ - internal class ConsumesMetadata : IConsumesMetadata - { - public ConsumesMetadata(string[] contentTypes) - { - if (contentTypes == null) - { - throw new ArgumentNullException(nameof(contentTypes)); - } - - ContentTypes = contentTypes; - } - - public IReadOnlyList ContentTypes { get; } - } -} diff --git a/src/Mvc/Mvc.Core/src/Routing/IConsumesMetadata.cs b/src/Mvc/Mvc.Core/src/Routing/IConsumesMetadata.cs deleted file mode 100644 index 01aa87207835..000000000000 --- a/src/Mvc/Mvc.Core/src/Routing/IConsumesMetadata.cs +++ /dev/null @@ -1,13 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - - -using System.Collections.Generic; - -namespace Microsoft.AspNetCore.Mvc.Routing -{ - internal interface IConsumesMetadata - { - IReadOnlyList ContentTypes { get; } - } -} diff --git a/src/Mvc/Mvc.Core/test/DependencyInjection/MvcCoreServiceCollectionExtensionsTest.cs b/src/Mvc/Mvc.Core/test/DependencyInjection/MvcCoreServiceCollectionExtensionsTest.cs index 47dcd3012705..5ec7086a0015 100644 --- a/src/Mvc/Mvc.Core/test/DependencyInjection/MvcCoreServiceCollectionExtensionsTest.cs +++ b/src/Mvc/Mvc.Core/test/DependencyInjection/MvcCoreServiceCollectionExtensionsTest.cs @@ -15,6 +15,7 @@ using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Mvc.Routing; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Matching; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; @@ -324,7 +325,6 @@ private Dictionary MultiRegistrationServiceTypes typeof(MatcherPolicy), new Type[] { - typeof(ConsumesMatcherPolicy), typeof(ActionConstraintMatcherPolicy), typeof(DynamicControllerEndpointMatcherPolicy), } diff --git a/src/Mvc/test/Mvc.FunctionalTests/SimpleWithWebApplicationBuilderTests.cs b/src/Mvc/test/Mvc.FunctionalTests/SimpleWithWebApplicationBuilderTests.cs index f77902bf0ee3..ca2ae7a27c12 100644 --- a/src/Mvc/test/Mvc.FunctionalTests/SimpleWithWebApplicationBuilderTests.cs +++ b/src/Mvc/test/Mvc.FunctionalTests/SimpleWithWebApplicationBuilderTests.cs @@ -21,8 +21,11 @@ public class SimpleWithWebApplicationBuilderTests : IClassFixture fixture) { _fixture = fixture; + Client = _fixture.CreateDefaultClient(); } + public HttpClient Client { get; } + [Fact] public async Task HelloWorld() { @@ -187,5 +190,45 @@ public async Task Environment_Can_Be_Overridden() // Assert Assert.Equal(expected, content); } + + [Fact] + public async Task Accepts_Json_WhenBindingAComplexType() + { + // Act + var response = await Client.PostAsJsonAsync("accepts-default", new { name = "Test" }); + + // Assert + await response.AssertStatusCodeAsync(HttpStatusCode.OK); + } + + [Fact] + public async Task Rejects_NonJson_WhenBindingAComplexType() + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Post, "accepts-default"); + request.Content = new StringContent(""); + request.Content.Headers.ContentType = new("application/xml"); + + // Act + var response = await Client.SendAsync(request); + + // Assert + await response.AssertStatusCodeAsync(HttpStatusCode.UnsupportedMediaType); + } + + [Fact] + public async Task Accepts_NonJsonMediaType() + { + // Arrange + var request = new HttpRequestMessage(HttpMethod.Post, "accepts-xml"); + request.Content = new StringContent(""); + request.Content.Headers.ContentType = new("application/xml"); + + // Act + var response = await Client.SendAsync(request); + + // Assert + await response.AssertStatusCodeAsync(HttpStatusCode.Accepted); + } } } diff --git a/src/Mvc/test/WebSites/SimpleWebSiteWithWebApplicationBuilder/Program.cs b/src/Mvc/test/WebSites/SimpleWebSiteWithWebApplicationBuilder/Program.cs index 213aec9278d2..b0b21f572be3 100644 --- a/src/Mvc/test/WebSites/SimpleWebSiteWithWebApplicationBuilder/Program.cs +++ b/src/Mvc/test/WebSites/SimpleWebSiteWithWebApplicationBuilder/Program.cs @@ -37,6 +37,9 @@ app.MapGet("/greeting", (IConfiguration config) => config["Greeting"]); +app.MapPost("/accepts-default", (Person person) => Results.Ok(person.Name)); +app.MapPost("/accepts-xml", () => Accepted()).Accepts("application/xml"); + app.Run(); record Person(string Name, int Age); @@ -45,4 +48,4 @@ public class MyController : ControllerBase { [HttpGet("/greet")] public string Greet() => $"Hello human"; -} \ No newline at end of file +} diff --git a/src/Shared/MediaType/HttpTokenParsingRule.cs b/src/Shared/MediaType/HttpTokenParsingRule.cs new file mode 100644 index 000000000000..0910a3a20349 --- /dev/null +++ b/src/Shared/MediaType/HttpTokenParsingRule.cs @@ -0,0 +1,277 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Text; + +namespace Microsoft.AspNetCore.Http.Headers; + +internal static class HttpTokenParsingRules +{ + private static readonly bool[] TokenChars = CreateTokenChars(); + private const int MaxNestedCount = 5; + + internal const char CR = '\r'; + internal const char LF = '\n'; + internal const char SP = ' '; + internal const char Tab = '\t'; + internal const int MaxInt64Digits = 19; + internal const int MaxInt32Digits = 10; + + // iso-8859-1, Western European (ISO) + internal static readonly Encoding DefaultHttpEncoding = Encoding.GetEncoding(28591); + + private static bool[] CreateTokenChars() + { + // token = 1* + // CTL = + + var tokenChars = new bool[128]; // everything is false + + for (var i = 33; i < 127; i++) // skip Space (32) & DEL (127) + { + tokenChars[i] = true; + } + + // remove separators: these are not valid token characters + tokenChars[(byte)'('] = false; + tokenChars[(byte)')'] = false; + tokenChars[(byte)'<'] = false; + tokenChars[(byte)'>'] = false; + tokenChars[(byte)'@'] = false; + tokenChars[(byte)','] = false; + tokenChars[(byte)';'] = false; + tokenChars[(byte)':'] = false; + tokenChars[(byte)'\\'] = false; + tokenChars[(byte)'"'] = false; + tokenChars[(byte)'/'] = false; + tokenChars[(byte)'['] = false; + tokenChars[(byte)']'] = false; + tokenChars[(byte)'?'] = false; + tokenChars[(byte)'='] = false; + tokenChars[(byte)'{'] = false; + tokenChars[(byte)'}'] = false; + + return tokenChars; + } + + internal static bool IsTokenChar(char character) + { + // Must be between 'space' (32) and 'DEL' (127) + if (character > 127) + { + return false; + } + + return TokenChars[character]; + } + + internal static int GetTokenLength(string input, int startIndex) + { + Debug.Assert(input != null); + + if (startIndex >= input.Length) + { + return 0; + } + + var current = startIndex; + + while (current < input.Length) + { + if (!IsTokenChar(input[current])) + { + return current - startIndex; + } + current++; + } + return input.Length - startIndex; + } + + internal static int GetWhitespaceLength(string input, int startIndex) + { + Debug.Assert(input != null); + + if (startIndex >= input.Length) + { + return 0; + } + + var current = startIndex; + + while (current < input.Length) + { + var c = input[current]; + + if ((c == SP) || (c == Tab)) + { + current++; + continue; + } + + if (c == CR) + { + // If we have a #13 char, it must be followed by #10 and then at least one SP or HT. + if ((current + 2 < input.Length) && (input[current + 1] == LF)) + { + var spaceOrTab = input[current + 2]; + if ((spaceOrTab == SP) || (spaceOrTab == Tab)) + { + current += 3; + continue; + } + } + } + + return current - startIndex; + } + + // All characters between startIndex and the end of the string are LWS characters. + return input.Length - startIndex; + } + + internal static HttpParseResult GetQuotedStringLength(string input, int startIndex, out int length) + { + var nestedCount = 0; + return GetExpressionLength(input, startIndex, '"', '"', false, ref nestedCount, out length); + } + + // quoted-pair = "\" CHAR + // CHAR = + internal static HttpParseResult GetQuotedPairLength(string input, int startIndex, out int length) + { + Debug.Assert(input != null); + Debug.Assert((startIndex >= 0) && (startIndex < input.Length)); + + length = 0; + + if (input[startIndex] != '\\') + { + return HttpParseResult.NotParsed; + } + + // Quoted-char has 2 characters. Check whether there are 2 chars left ('\' + char) + // If so, check whether the character is in the range 0-127. If not, it's an invalid value. + if ((startIndex + 2 > input.Length) || (input[startIndex + 1] > 127)) + { + return HttpParseResult.InvalidFormat; + } + + // We don't care what the char next to '\' is. + length = 2; + return HttpParseResult.Parsed; + } + + // TEXT = + // LWS = [CRLF] 1*( SP | HT ) + // CTL = + // + // Since we don't really care about the content of a quoted string or comment, we're more tolerant and + // allow these characters. We only want to find the delimiters ('"' for quoted string and '(', ')' for comment). + // + // 'nestedCount': Comments can be nested. We allow a depth of up to 5 nested comments, i.e. something like + // "(((((comment)))))". If we wouldn't define a limit an attacker could send a comment with hundreds of nested + // comments, resulting in a stack overflow exception. In addition having more than 1 nested comment (if any) + // is unusual. + private static HttpParseResult GetExpressionLength( + string input, + int startIndex, + char openChar, + char closeChar, + bool supportsNesting, + ref int nestedCount, + out int length) + { + Debug.Assert(input != null); + Debug.Assert((startIndex >= 0) && (startIndex < input.Length)); + + length = 0; + + if (input[startIndex] != openChar) + { + return HttpParseResult.NotParsed; + } + + var current = startIndex + 1; // Start parsing with the character next to the first open-char + while (current < input.Length) + { + // Only check whether we have a quoted char, if we have at least 3 characters left to read (i.e. + // quoted char + closing char). Otherwise the closing char may be considered part of the quoted char. + if ((current + 2 < input.Length) && + (GetQuotedPairLength(input, current, out var quotedPairLength) == HttpParseResult.Parsed)) + { + // We ignore invalid quoted-pairs. Invalid quoted-pairs may mean that it looked like a quoted pair, + // but we actually have a quoted-string: e.g. "\ü" ('\' followed by a char >127 - quoted-pair only + // allows ASCII chars after '\'; qdtext allows both '\' and >127 chars). + current = current + quotedPairLength; + continue; + } + + // If we support nested expressions and we find an open-char, then parse the nested expressions. + if (supportsNesting && (input[current] == openChar)) + { + nestedCount++; + try + { + // Check if we exceeded the number of nested calls. + if (nestedCount > MaxNestedCount) + { + return HttpParseResult.InvalidFormat; + } + + var nestedResult = GetExpressionLength( + input, + current, + openChar, + closeChar, + supportsNesting, + ref nestedCount, + out var nestedLength); + + switch (nestedResult) + { + case HttpParseResult.Parsed: + current += nestedLength; // add the length of the nested expression and continue. + break; + + case HttpParseResult.NotParsed: + Debug.Fail("'NotParsed' is unexpected: We started nested expression " + + "parsing, because we found the open-char. So either it's a valid nested " + + "expression or it has invalid format."); + break; + + case HttpParseResult.InvalidFormat: + // If the nested expression is invalid, we can't continue, so we fail with invalid format. + return HttpParseResult.InvalidFormat; + + default: + Debug.Fail("Unknown enum result: " + nestedResult); + break; + } + } + finally + { + nestedCount--; + } + } + + if (input[current] == closeChar) + { + length = current - startIndex + 1; + return HttpParseResult.Parsed; + } + current++; + } + + // We didn't see the final quote, therefore we have an invalid expression string. + return HttpParseResult.InvalidFormat; + } +} + +internal enum HttpParseResult +{ + Parsed, + NotParsed, + InvalidFormat, +} + diff --git a/src/Shared/MediaType/ReadOnlyMediaTypeHeaderValue.cs b/src/Shared/MediaType/ReadOnlyMediaTypeHeaderValue.cs new file mode 100644 index 000000000000..7eaa8903334e --- /dev/null +++ b/src/Shared/MediaType/ReadOnlyMediaTypeHeaderValue.cs @@ -0,0 +1,625 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Http.Headers; + +/// +/// A media type value. +/// +internal readonly struct ReadOnlyMediaTypeHeaderValue +{ + /// + /// Initializes a instance. + /// + /// The with the media type. + public ReadOnlyMediaTypeHeaderValue(string mediaType) + : this(mediaType, 0, mediaType.Length) + { + } + + /// + /// Initializes a instance. + /// + /// The with the media type. + public ReadOnlyMediaTypeHeaderValue(StringSegment mediaType) + : this(mediaType.Buffer, mediaType.Offset, mediaType.Length) + { + } + + /// + /// Initializes a instance. + /// + /// The with the media type. + /// The offset in the where the parsing starts. + /// The length of the media type to parse if provided. + public ReadOnlyMediaTypeHeaderValue(string mediaType, int offset, int? length) + { + ParameterParser = default(MediaTypeParameterParser); + + var typeLength = GetTypeLength(mediaType, offset, out var type); + if (typeLength == 0) + { + Type = new StringSegment(); + SubType = new StringSegment(); + SubTypeWithoutSuffix = new StringSegment(); + SubTypeSuffix = new StringSegment(); + return; + } + else + { + Type = type; + } + + var subTypeLength = GetSubtypeLength(mediaType, offset + typeLength, out var subType); + if (subTypeLength == 0) + { + SubType = new StringSegment(); + SubTypeWithoutSuffix = new StringSegment(); + SubTypeSuffix = new StringSegment(); + return; + } + else + { + SubType = subType; + + if (TryGetSuffixLength(subType, out var subtypeSuffixLength)) + { + SubTypeWithoutSuffix = subType.Subsegment(0, subType.Length - subtypeSuffixLength - 1); + SubTypeSuffix = subType.Subsegment(subType.Length - subtypeSuffixLength, subtypeSuffixLength); + } + else + { + SubTypeWithoutSuffix = SubType; + SubTypeSuffix = new StringSegment(); + } + } + + ParameterParser = new MediaTypeParameterParser(mediaType, offset + typeLength + subTypeLength, length); + } + + // All GetXXXLength methods work in the same way. They expect to be on the right position for + // the token they are parsing, for example, the beginning of the media type or the delimiter + // from a previous token, like '/', ';' or '='. + // Each method consumes the delimiter token if any, the leading whitespace, then the given token + // itself, and finally the trailing whitespace. + private static int GetTypeLength(string input, int offset, out StringSegment type) + { + if (offset < 0 || offset >= input.Length) + { + type = default(StringSegment); + return 0; + } + + var current = offset + HttpTokenParsingRules.GetWhitespaceLength(input, offset); + + // Parse the type, i.e. in media type string "/; param1=value1; param2=value2" + var typeLength = HttpTokenParsingRules.GetTokenLength(input, current); + if (typeLength == 0) + { + type = default(StringSegment); + return 0; + } + + type = new StringSegment(input, current, typeLength); + + current += typeLength; + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + return current - offset; + } + + private static int GetSubtypeLength(string input, int offset, out StringSegment subType) + { + var current = offset; + + // Parse the separator between type and subtype + if (current < 0 || current >= input.Length || input[current] != '/') + { + subType = default(StringSegment); + return 0; + } + + current++; // skip delimiter. + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + var subtypeLength = HttpTokenParsingRules.GetTokenLength(input, current); + if (subtypeLength == 0) + { + subType = default(StringSegment); + return 0; + } + + subType = new StringSegment(input, current, subtypeLength); + + current += subtypeLength; + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + return current - offset; + } + + private static bool TryGetSuffixLength(StringSegment subType, out int suffixLength) + { + // Find the last instance of '+', if there is one + var startPos = subType.Offset + subType.Length - 1; + for (var currentPos = startPos; currentPos >= subType.Offset; currentPos--) + { + if (subType.Buffer[currentPos] == '+') + { + suffixLength = startPos - currentPos; + return true; + } + } + + suffixLength = 0; + return false; + } + + /// + /// Gets the type of the . + /// + /// + /// For the media type "application/json", this property gives the value "application". + /// + public StringSegment Type { get; } + + /// + /// Gets whether this matches all types. + /// + public bool MatchesAllTypes => Type.Equals("*", StringComparison.OrdinalIgnoreCase); + + /// + /// Gets the subtype of the . + /// + /// + /// For the media type "application/vnd.example+json", this property gives the value + /// "vnd.example+json". + /// + public StringSegment SubType { get; } + + /// + /// Gets the subtype of the , excluding any structured syntax suffix. + /// + /// + /// For the media type "application/vnd.example+json", this property gives the value + /// "vnd.example". + /// + public StringSegment SubTypeWithoutSuffix { get; } + + /// + /// Gets the structured syntax suffix of the if it has one. + /// + /// + /// For the media type "application/vnd.example+json", this property gives the value + /// "json". + /// + public StringSegment SubTypeSuffix { get; } + + /// + /// Gets whether this matches all subtypes. + /// + /// + /// For the media type "application/*", this property is true. + /// + /// + /// For the media type "application/json", this property is false. + /// + public bool MatchesAllSubTypes => SubType.Equals("*", StringComparison.OrdinalIgnoreCase); + + /// + /// Gets whether this matches all subtypes, ignoring any structured syntax suffix. + /// + /// + /// For the media type "application/*+json", this property is true. + /// + /// + /// For the media type "application/vnd.example+json", this property is false. + /// + public bool MatchesAllSubTypesWithoutSuffix => SubTypeWithoutSuffix.Equals("*", StringComparison.OrdinalIgnoreCase); + + /// + /// Gets the of the if it has one. + /// + public Encoding? Encoding => GetEncodingFromCharset(GetParameter("charset")); + + /// + /// Gets the charset parameter of the if it has one. + /// + public StringSegment Charset => GetParameter("charset"); + + /// + /// Determines whether the current contains a wildcard. + /// + /// + /// true if this contains a wildcard; otherwise false. + /// + public bool HasWildcard + { + get + { + return MatchesAllTypes || + MatchesAllSubTypesWithoutSuffix || + GetParameter("*").Equals("*", StringComparison.OrdinalIgnoreCase); + } + } + + public MediaTypeParameterParser ParameterParser { get; } + + /// + /// Determines whether the current is a subset of the + /// . + /// + /// The set . + /// + /// true if this is a subset of ; otherwise false. + /// + public bool IsSubsetOf(ReadOnlyMediaTypeHeaderValue set) + { + return MatchesType(set) && + MatchesSubtype(set) && + ContainsAllParameters(set.ParameterParser); + } + + /// + /// Gets the parameter of the media type. + /// + /// The name of the parameter to retrieve. + /// + /// The for the given if found; otherwise + /// null. + /// + public StringSegment GetParameter(string parameterName) + { + return GetParameter(new StringSegment(parameterName)); + } + + /// + /// Gets the parameter of the media type. + /// + /// The name of the parameter to retrieve. + /// + /// The for the given if found; otherwise + /// null. + /// + public StringSegment GetParameter(StringSegment parameterName) + { + var parametersParser = ParameterParser; + + while (parametersParser.ParseNextParameter(out var parameter)) + { + if (parameter.HasName(parameterName)) + { + return parameter.Value; + } + } + + return new StringSegment(); + } + + /// + /// Gets the last parameter of the media type. + /// + /// The name of the parameter to retrieve. + /// The value for the last parameter + /// + /// if parsing succeeded. + /// + public bool TryGetLastParameter(StringSegment parameterName, out StringSegment parameterValue) + { + var parametersParser = ParameterParser; + + parameterValue = default; + while (parametersParser.ParseNextParameter(out var parameter)) + { + if (parameter.HasName(parameterName)) + { + parameterValue = parameter.Value; + } + } + + return !parametersParser.ParsingFailed; + } + + /// + /// Get an encoding for a mediaType. + /// + /// The mediaType. + /// The encoding. + public static Encoding? GetEncoding(string mediaType) + { + return GetEncoding(new StringSegment(mediaType)); + } + + /// + /// Get an encoding for a mediaType. + /// + /// The mediaType. + /// The encoding. + public static Encoding? GetEncoding(StringSegment mediaType) + { + var parsedMediaType = new ReadOnlyMediaTypeHeaderValue(mediaType); + return parsedMediaType.Encoding; + } + + private static Encoding? GetEncodingFromCharset(StringSegment charset) + { + if (charset.Equals("utf-8", StringComparison.OrdinalIgnoreCase)) + { + // This is an optimization for utf-8 that prevents the Substring caused by + // charset.Value + return Encoding.UTF8; + } + + try + { + // charset.Value might be an invalid encoding name as in charset=invalid. + // For that reason, we catch the exception thrown by Encoding.GetEncoding + // and return null instead. + return charset.HasValue ? Encoding.GetEncoding(charset.Value) : null; + } + catch (Exception) + { + return null; + } + } + + private static string CreateMediaTypeWithEncoding(StringSegment mediaType, Encoding encoding) + { + return $"{mediaType.Value}; charset={encoding.WebName}"; + } + + private bool MatchesType(ReadOnlyMediaTypeHeaderValue set) + { + return set.MatchesAllTypes || + set.Type.Equals(Type, StringComparison.OrdinalIgnoreCase); + } + + private bool MatchesSubtype(ReadOnlyMediaTypeHeaderValue set) + { + if (set.MatchesAllSubTypes) + { + return true; + } + + if (set.SubTypeSuffix.HasValue) + { + if (SubTypeSuffix.HasValue) + { + // Both the set and the media type being checked have suffixes, so both parts must match. + return MatchesSubtypeWithoutSuffix(set) && MatchesSubtypeSuffix(set); + } + else + { + // The set has a suffix, but the media type being checked doesn't. We never consider this to match. + return false; + } + } + else + { + // If this subtype or suffix matches the subtype of the set, + // it is considered a subtype. + // Ex: application/json > application/val+json + return MatchesEitherSubtypeOrSuffix(set); + } + } + + private bool MatchesSubtypeWithoutSuffix(ReadOnlyMediaTypeHeaderValue set) + { + return set.MatchesAllSubTypesWithoutSuffix || + set.SubTypeWithoutSuffix.Equals(SubTypeWithoutSuffix, StringComparison.OrdinalIgnoreCase); + } + + private bool MatchesSubtypeSuffix(ReadOnlyMediaTypeHeaderValue set) + { + // We don't have support for wildcards on suffixes alone (e.g., "application/entity+*") + // because there's no clear use case for it. + return set.SubTypeSuffix.Equals(SubTypeSuffix, StringComparison.OrdinalIgnoreCase); + } + + private bool MatchesEitherSubtypeOrSuffix(ReadOnlyMediaTypeHeaderValue set) + { + return set.SubType.Equals(SubType, StringComparison.OrdinalIgnoreCase) || + set.SubType.Equals(SubTypeSuffix, StringComparison.OrdinalIgnoreCase); + } + + private bool ContainsAllParameters(MediaTypeParameterParser setParameters) + { + var parameterFound = true; + while (setParameters.ParseNextParameter(out var setParameter) && parameterFound) + { + if (setParameter.HasName("q")) + { + // "q" and later parameters are not involved in media type matching. Quoting the RFC: The first + // "q" parameter (if any) separates the media-range parameter(s) from the accept-params. + break; + } + + if (setParameter.HasName("*")) + { + // A parameter named "*" has no effect on media type matching, as it is only used as an indication + // that the entire media type string should be treated as a wildcard. + continue; + } + + // Copy the parser as we need to iterate multiple times over it. + // We can do this because it's a struct + var subSetParameters = ParameterParser; + parameterFound = false; + while (subSetParameters.ParseNextParameter(out var subSetParameter) && !parameterFound) + { + parameterFound = subSetParameter.Equals(setParameter); + } + } + + return parameterFound; + } + + public struct MediaTypeParameterParser + { + private readonly string _mediaTypeBuffer; + private readonly int? _length; + + public MediaTypeParameterParser(string mediaTypeBuffer, int offset, int? length) + { + _mediaTypeBuffer = mediaTypeBuffer; + _length = length; + CurrentOffset = offset; + ParsingFailed = false; + } + + public int CurrentOffset { get; private set; } + + public bool ParsingFailed { get; private set; } + + public bool ParseNextParameter(out MediaTypeParameter result) + { + if (_mediaTypeBuffer == null) + { + ParsingFailed = true; + result = default(MediaTypeParameter); + return false; + } + + var parameterLength = GetParameterLength(_mediaTypeBuffer, CurrentOffset, out result); + CurrentOffset += parameterLength; + + if (parameterLength == 0) + { + ParsingFailed = _length != null && CurrentOffset < _length; + return false; + } + + return true; + } + + private static int GetParameterLength(string input, int startIndex, out MediaTypeParameter parsedValue) + { + if (OffsetIsOutOfRange(startIndex, input.Length) || input[startIndex] != ';') + { + parsedValue = default(MediaTypeParameter); + return 0; + } + + var nameLength = GetNameLength(input, startIndex, out var name); + + var current = startIndex + nameLength; + + if (nameLength == 0 || OffsetIsOutOfRange(current, input.Length) || input[current] != '=') + { + if (current == input.Length && name.Equals("*", StringComparison.OrdinalIgnoreCase)) + { + // As a special case, we allow a trailing ";*" to indicate a wildcard + // string allowing any other parameters. It's the same as ";*=*". + var asterisk = new StringSegment("*"); + parsedValue = new MediaTypeParameter(asterisk, asterisk); + return current - startIndex; + } + else + { + parsedValue = default(MediaTypeParameter); + return 0; + } + } + + var valueLength = GetValueLength(input, current, out var value); + + parsedValue = new MediaTypeParameter(name, value); + current += valueLength; + + return current - startIndex; + } + + private static int GetNameLength(string input, int startIndex, out StringSegment name) + { + var current = startIndex; + + current++; // skip ';' + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + var nameLength = HttpTokenParsingRules.GetTokenLength(input, current); + if (nameLength == 0) + { + name = default(StringSegment); + return 0; + } + + name = new StringSegment(input, current, nameLength); + + current += nameLength; + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + return current - startIndex; + } + + private static int GetValueLength(string input, int startIndex, out StringSegment value) + { + var current = startIndex; + + current++; // skip '='. + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + var valueLength = HttpTokenParsingRules.GetTokenLength(input, current); + + if (valueLength == 0) + { + // A value can either be a token or a quoted string. Check if it is a quoted string. + var result = HttpTokenParsingRules.GetQuotedStringLength(input, current, out valueLength); + if (result != HttpParseResult.Parsed) + { + // We have an invalid value. Reset the name and return. + value = default(StringSegment); + return 0; + } + + // Quotation marks are not part of a quoted parameter value. + value = new StringSegment(input, current + 1, valueLength - 2); + } + else + { + value = new StringSegment(input, current, valueLength); + } + + current += valueLength; + current += HttpTokenParsingRules.GetWhitespaceLength(input, current); + + return current - startIndex; + } + + private static bool OffsetIsOutOfRange(int offset, int length) + { + return offset < 0 || offset >= length; + } + } + + public readonly struct MediaTypeParameter : IEquatable + { + public MediaTypeParameter(StringSegment name, StringSegment value) + { + Name = name; + Value = value; + } + + public StringSegment Name { get; } + + public StringSegment Value { get; } + + public bool HasName(string name) + { + return HasName(new StringSegment(name)); + } + + public bool HasName(StringSegment name) + { + return Name.Equals(name, StringComparison.OrdinalIgnoreCase); + } + + public bool Equals(MediaTypeParameter other) + { + return HasName(other.Name) && Value.Equals(other.Value, StringComparison.OrdinalIgnoreCase); + } + + public override string ToString() => $"{Name}={Value}"; + } +} diff --git a/src/Shared/RoutingMetadata/AcceptsMetadata.cs b/src/Shared/RoutingMetadata/AcceptsMetadata.cs new file mode 100644 index 000000000000..eadfd5deb565 --- /dev/null +++ b/src/Shared/RoutingMetadata/AcceptsMetadata.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Http.Metadata +{ + /// + /// Metadata that specifies the supported request content types. + /// + internal sealed class AcceptsMetadata : IAcceptsMetadata + { + /// + /// Creates a new instance of . + /// + public AcceptsMetadata(string[] contentTypes) + { + if (contentTypes == null) + { + throw new ArgumentNullException(nameof(contentTypes)); + } + + ContentTypes = contentTypes; + } + + /// + /// Creates a new instance of with a type. + /// + public AcceptsMetadata(Type? type, string[] contentTypes) + { + RequestType = type ?? throw new ArgumentNullException(nameof(type)); + + if (contentTypes == null) + { + throw new ArgumentNullException(nameof(contentTypes)); + } + + ContentTypes = contentTypes; + } + + /// + /// Gets the supported request content types. + /// + public IReadOnlyList ContentTypes { get; } + + /// + /// Accepts request content types of any shape. + /// + public Type? RequestType { get; } + } +}