From f4254d90a6795af857d00dae4f54f8217a78ed9f Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 19 Aug 2021 03:29:12 -0700 Subject: [PATCH 1/4] Added support for type based parameter binding - Added a convention that allows custom async binding logic to run for parameters that have a static BindAsync method that takes an HttpContext and return a ValueTask of object. This allows customers to write custom binders based solely on type (it's an extension of the existing TryParse pattern). - There's allocation overhead per request once there's a parameter binder for a delegate. This is because we need to box all of the arguments since we're not using generated code to compute data from the list of binders. - Changed TryParse tests to BindAsync tests and added more tests. --- .../src/RequestDelegateFactory.cs | 186 ++++++++++----- .../test/RequestDelegateFactoryTests.cs | 218 ++++++++++++++---- .../test/TryParseMethodCacheTests.cs | 61 +++-- .../EndpointMetadataApiDescriptionProvider.cs | 2 +- src/Shared/TryParseMethodCache.cs | 20 +- 5 files changed, 339 insertions(+), 148 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index fb2239d64401..9052c3a06d38 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -41,12 +41,11 @@ public static partial class RequestDelegateFactory Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue)); private static readonly MethodInfo LogRequiredParameterNotProvidedMethod = GetMethodInfo>((httpContext, parameterType, parameterName) => Log.RequiredParameterNotProvided(httpContext, parameterType, parameterName)); - private static readonly MethodInfo LogParameterBindingFromHttpContextFailedMethod = GetMethodInfo>((httpContext, parameterType, parameterName) => - Log.ParameterBindingFromHttpContextFailed(httpContext, parameterType, parameterName)); private static readonly ParameterExpression TargetExpr = Expression.Parameter(typeof(object), "target"); private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue"); private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure"); + private static readonly ParameterExpression BoundValuesArrayExpr = Expression.Parameter(typeof(object[]), "boundValues"); private static ParameterExpression HttpContextExpr => TryParseMethodCache.HttpContextExpr; private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.RequestServices))!); @@ -188,23 +187,23 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory } var args = new Expression[parameters.Length]; + factoryContext.ParameterCount = parameters.Length; for (var i = 0; i < parameters.Length; i++) { - args[i] = CreateArgument(parameters[i], factoryContext); + args[i] = CreateArgument(i, parameters[i], factoryContext); } if (factoryContext.HasMultipleBodyParameters) { var errorMessage = BuildErrorMessageForMultipleBodyParameters(factoryContext); throw new InvalidOperationException(errorMessage); - } return args; } - private static Expression CreateArgument(ParameterInfo parameter, FactoryContext factoryContext) + private static Expression CreateArgument(int index, ParameterInfo parameter, FactoryContext factoryContext) { if (parameter.Name is null) { @@ -263,9 +262,9 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext { return RequestAbortedExpr; } - else if (TryParseMethodCache.HasTryParseHttpContextMethod(parameter)) + else if (TryParseMethodCache.HasBindAsyncMethod(parameter)) { - return BindParameterFromTryParseHttpContext(parameter, factoryContext); + return BindParameterFromBindAsync(index, parameter, factoryContext); } else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseStringMethod(parameter)) { @@ -275,7 +274,7 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext // when RDF.Create is manually invoked. if (factoryContext.RouteParameters is { } routeParams) { - + if (routeParams.Contains(parameter.Name, StringComparer.OrdinalIgnoreCase)) { // We're in the fallback case and we have a parameter and route parameter match so don't fallback @@ -361,7 +360,6 @@ private static Expression CreateParamCheckingResponseWritingMethodCall( var localVariables = new ParameterExpression[factoryContext.ExtraLocals.Count + 1]; var checkParamAndCallMethod = new Expression[factoryContext.ParamCheckExpressions.Count + 1]; - for (var i = 0; i < factoryContext.ExtraLocals.Count; i++) { localVariables[i] = factoryContext.ExtraLocals[i]; @@ -508,14 +506,33 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, { if (factoryContext.JsonRequestBodyType is null) { + if (factoryContext.ParameterBinders.Count > 0) + { + // We need to generate the code for reading from the custom binders calling into the delegate + var continuation = Expression.Lambda>( + responseWritingMethodCall, TargetExpr, HttpContextExpr, BoundValuesArrayExpr).Compile(); + + // Looping over arrays is faster + var binders = factoryContext.ParameterBinders.ToArray(); + var count = factoryContext.ParameterCount; + + return async (target, httpContext) => + { + var boundValues = new object?[count]; + + foreach (var (index, binder) in binders) + { + boundValues[index] = await binder(httpContext); + } + + await continuation(target, httpContext, boundValues); + }; + } + return Expression.Lambda>( responseWritingMethodCall, TargetExpr, HttpContextExpr).Compile(); } - // We need to generate the code for reading from the body before calling into the delegate - var invoker = Expression.Lambda>( - responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile(); - var bodyType = factoryContext.JsonRequestBodyType; object? defaultBodyValue = null; @@ -524,31 +541,82 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, defaultBodyValue = Activator.CreateInstance(bodyType); } - return async (target, httpContext) => + if (factoryContext.ParameterBinders.Count > 0) { - object? bodyValue = defaultBodyValue; - var feature = httpContext.Features.Get(); - if (feature?.CanHaveBody == true) + // We need to generate the code for reading from the body before calling into the delegate + var continuation = Expression.Lambda>( + responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr, BoundValuesArrayExpr).Compile(); + + // Looping over arrays is faster + var binders = factoryContext.ParameterBinders.ToArray(); + var count = factoryContext.ParameterCount; + + return async (target, httpContext) => { - try + // Run these first so that they can potentially read and rewind the body + var boundValues = new object?[count]; + + foreach (var (index, binder) in binders) { - bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); + boundValues[index] = await binder(httpContext); } - catch (IOException ex) + + var bodyValue = defaultBodyValue; + var feature = httpContext.Features.Get(); + if (feature?.CanHaveBody == true) { - Log.RequestBodyIOException(httpContext, ex); - return; + try + { + bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); + } + catch (IOException ex) + { + Log.RequestBodyIOException(httpContext, ex); + return; + } + catch (InvalidDataException ex) + { + Log.RequestBodyInvalidDataException(httpContext, ex); + httpContext.Response.StatusCode = 400; + return; + } } - catch (InvalidDataException ex) - { - Log.RequestBodyInvalidDataException(httpContext, ex); - httpContext.Response.StatusCode = 400; - return; + await continuation(target, httpContext, bodyValue, boundValues); + }; + } + else + { + // We need to generate the code for reading from the body before calling into the delegate + var continuation = Expression.Lambda>( + responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile(); + + return async (target, httpContext) => + { + var bodyValue = defaultBodyValue; + var feature = httpContext.Features.Get(); + if (feature?.CanHaveBody == true) + { + try + { + bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); + } + catch (IOException ex) + { + Log.RequestBodyIOException(httpContext, ex); + return; + } + catch (InvalidDataException ex) + { + + Log.RequestBodyInvalidDataException(httpContext, ex); + httpContext.Response.StatusCode = 400; + return; + } } - } - await invoker(target, httpContext, bodyValue); - }; + await continuation(target, httpContext, bodyValue); + }; + } } private static Expression GetValueFromProperty(Expression sourceExpression, string key) @@ -747,40 +815,40 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext); } - private static Expression BindParameterFromTryParseHttpContext(ParameterInfo parameter, FactoryContext factoryContext) + private static Expression BindParameterFromBindAsync(int index, ParameterInfo parameter, FactoryContext factoryContext) { - // bool wasParamCheckFailure = false; - // - // // Assume "Foo param1" is the first parameter and "public static bool TryParse(HttpContext context, out Foo foo)" exists. - // Foo param1_local; - // - // if (!Foo.TryParse(httpContext, out param1_local)) - // { - // wasParamCheckFailure = true; - // Log.ParameterBindingFromHttpContextFailed(httpContext, "Foo", "foo") - // } - - var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local"); - var tryParseMethodCall = TryParseMethodCache.FindTryParseHttpContextMethod(parameter.ParameterType); + // We reference the boundValues array by parameter index here + var nullability = NullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; - // There's no way to opt-in to using a TryParse method on HttpContext other than defining the method, so it's guaranteed to exist here. - Debug.Assert(tryParseMethodCall is not null); + // Get the BindAsync method + var body = TryParseMethodCache.FindBindAsyncMethod(parameter.ParameterType)!; - var parameterTypeNameConstant = Expression.Constant(parameter.ParameterType.Name); - var parameterNameConstant = Expression.Constant(parameter.Name); + // Compile the delegate to the BindAsync method for this parameter index + var bindAsyncDelegate = Expression.Lambda>>(body, HttpContextExpr).Compile(); + factoryContext.ParameterBinders.Add((index, bindAsyncDelegate)); - var failBlock = Expression.Block( - Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), - Expression.Call(LogParameterBindingFromHttpContextFailedMethod, - HttpContextExpr, parameterTypeNameConstant, parameterNameConstant)); + // boundValues[index] + var boundValueExpr = Expression.ArrayIndex(BoundValuesArrayExpr, Expression.Constant(index)); - var tryParseCall = tryParseMethodCall(argument); - var fullParamCheckBlock = Expression.IfThen(Expression.Not(tryParseCall), failBlock); + if (!isOptional) + { + var checkRequiredBodyBlock = Expression.Block( + Expression.IfThen( + Expression.Equal(boundValueExpr, Expression.Constant(null)), + Expression.Block( + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) + ) + ) + ); - factoryContext.ExtraLocals.Add(argument); - factoryContext.ParamCheckExpressions.Add(fullParamCheckBlock); + factoryContext.ParamCheckExpressions.Add(checkRequiredBodyBlock); + } - return argument; + // (ParamterType)boundValues[i] + return Expression.Convert(boundValueExpr, parameter.ParameterType); } private static Expression BindParameterFromBody(ParameterInfo parameter, bool allowEmpty, FactoryContext factoryContext) @@ -793,7 +861,6 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al { factoryContext.TrackedParameters.Remove(parameterName); factoryContext.TrackedParameters.Add(parameterName, "UNKNOWN"); - } } @@ -925,7 +992,6 @@ static async Task ExecuteAwaited(Task task, HttpContext httpContext) private static Task ExecuteTaskOfString(Task task, HttpContext httpContext) { - SetPlaintextContentType(httpContext); EnsureRequestTaskNotNull(task); @@ -1032,6 +1098,8 @@ private class FactoryContext public bool UsingTempSourceString { get; set; } public List ExtraLocals { get; } = new(); public List ParamCheckExpressions { get; } = new(); + public List<(int, Func>)> ParameterBinders { get; } = new(); + public int ParameterCount { get; set; } public Dictionary TrackedParameters { get; } = new(); public bool HasMultipleBodyParameters { get; set; } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 621a9ce5a474..00832e35454a 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -499,45 +499,49 @@ public static bool TryParse(string? value, out MyTryParseStringRecord? result) } } - private record MyTryParseHttpContextRecord(Uri Uri) + private class MyBindAsyncTypeThatThrows { - public static bool TryParse(HttpContext context, out MyTryParseHttpContextRecord? result) + public static ValueTask BindAsync(HttpContext context) + { + throw new InvalidOperationException("BindAsync failed"); + } + } + + private record MyBindAsyncRecord(Uri Uri) + { + public static ValueTask BindAsync(HttpContext context) { if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri)) { - result = null; - return false; + return ValueTask.FromResult(null); } - result = new MyTryParseHttpContextRecord(uri); - return true; + return ValueTask.FromResult(new MyBindAsyncRecord(uri)); } // TryParse(HttpContext, ...) should be preferred over TryParse(string, ...) if there's // no [FromRoute] or [FromQuery] attributes. - public static bool TryParse(string? value, out MyTryParseHttpContextRecord? result) + public static bool TryParse(string? value, out MyBindAsyncRecord? result) { throw new NotImplementedException(); } } - private record struct MyTryParseHttpContextStruct(Uri Uri) + private record struct MyBindAsyncStruct(Uri Uri) { - public static bool TryParse(HttpContext context, out MyTryParseHttpContextStruct result) + public static ValueTask BindAsync(HttpContext context) { if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri)) { - result = default; - return false; + return ValueTask.FromResult(null); } - result = new MyTryParseHttpContextStruct(uri); - return true; + return ValueTask.FromResult(new MyBindAsyncStruct(uri)); } // TryParse(HttpContext, ...) should be preferred over TryParse(string, ...) if there's // no [FromRoute] or [FromQuery] attributes. - public static bool TryParse(string? value, out MyTryParseHttpContextStruct result) => + public static bool TryParse(string? value, out MyBindAsyncStruct result) => throw new NotImplementedException(); } @@ -604,44 +608,44 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR } [Fact] - public async Task RequestDelegatePrefersTryParseHttpContextOverTryParseString() + public async Task RequestDelegatePrefersBindAsyncOverTryParseString() { var httpContext = new DefaultHttpContext(); httpContext.Request.Headers.Referer = "https://example.org"; - var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyTryParseHttpContextRecord tryParsable) => + var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncRecord tryParsable) => { httpContext.Items["tryParsable"] = tryParsable; }); await requestDelegate(httpContext); - Assert.Equal(new MyTryParseHttpContextRecord(new Uri("https://example.org")), httpContext.Items["tryParsable"]); + Assert.Equal(new MyBindAsyncRecord(new Uri("https://example.org")), httpContext.Items["tryParsable"]); } [Fact] - public async Task RequestDelegatePrefersTryParseHttpContextOverTryParseStringForNonNullableStruct() + public async Task RequestDelegatePrefersBindAsyncOverTryParseStringForNonNullableStruct() { var httpContext = new DefaultHttpContext(); httpContext.Request.Headers.Referer = "https://example.org"; - var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyTryParseHttpContextStruct tryParsable) => + var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct tryParsable) => { httpContext.Items["tryParsable"] = tryParsable; }); await requestDelegate(httpContext); - Assert.Equal(new MyTryParseHttpContextStruct(new Uri("https://example.org")), httpContext.Items["tryParsable"]); + Assert.Equal(new MyBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["tryParsable"]); } [Fact] - public async Task RequestDelegateUsesTryParseStringoOverTryParseHttpContextGivenExplicitAttribute() + public async Task RequestDelegateUsesTryParseStringoOverBindAsyncGivenExplicitAttribute() { - var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromRoute] MyTryParseHttpContextRecord tryParsable) => { }); - var fromQueryRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromQuery] MyTryParseHttpContextRecord tryParsable) => { }); + var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromRoute] MyBindAsyncRecord tryParsable) => { }); + var fromQueryRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromQuery] MyBindAsyncRecord tryParsable) => { }); var httpContext = new DefaultHttpContext { @@ -663,9 +667,9 @@ public async Task RequestDelegateUsesTryParseStringoOverTryParseHttpContextGiven } [Fact] - public async Task RequestDelegateUsesTryParseStringoOverTryParseHttpContextGivenNullableStruct() + public async Task RequestDelegateUsesTryParseStringOverBindAsyncGivenNullableStruct() { - var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyTryParseHttpContextStruct? tryParsable) => { }); + var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct? tryParsable) => { }); var httpContext = new DefaultHttpContext { @@ -756,7 +760,7 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2) } [Fact] - public async Task RequestDelegateLogsTryParseHttpContextFailuresAndSets400Response() + public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response() { // Not supplying any headers will cause the HttpContext TryParse overload to fail. var httpContext = new DefaultHttpContext() @@ -766,7 +770,7 @@ public async Task RequestDelegateLogsTryParseHttpContextFailuresAndSets400Respon var invoked = false; - var requestDelegate = RequestDelegateFactory.Create((MyTryParseHttpContextRecord arg1, MyTryParseHttpContextRecord arg2) => + var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncRecord arg1, MyBindAsyncRecord arg2) => { invoked = true; }); @@ -781,13 +785,136 @@ public async Task RequestDelegateLogsTryParseHttpContextFailuresAndSets400Respon Assert.Equal(2, logs.Length); - Assert.Equal(new EventId(5, "ParamaterBindingFromHttpContextFailed"), logs[0].EventId); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[0].EventId); Assert.Equal(LogLevel.Debug, logs[0].LogLevel); - Assert.Equal(@"Failed to bind parameter ""MyTryParseHttpContextRecord arg1"" from HttpContext.", logs[0].Message); + Assert.Equal(@"Required parameter ""MyBindAsyncRecord arg1"" was not provided.", logs[0].Message); - Assert.Equal(new EventId(5, "ParamaterBindingFromHttpContextFailed"), logs[1].EventId); + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[1].EventId); Assert.Equal(LogLevel.Debug, logs[1].LogLevel); - Assert.Equal(@"Failed to bind parameter ""MyTryParseHttpContextRecord arg2"" from HttpContext.", logs[1].Message); + Assert.Equal(@"Required parameter ""MyBindAsyncRecord arg2"" was not provided.", logs[1].Message); + } + + [Fact] + public async Task BindAsyncExceptionsThrowException() + { + // Not supplying any headers will cause the HttpContext TryParse overload to fail. + var httpContext = new DefaultHttpContext() + { + RequestServices = new ServiceCollection().AddSingleton(LoggerFactory).BuildServiceProvider(), + }; + + var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncTypeThatThrows arg1) => { }); + + var ex = await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + Assert.Equal("BindAsync failed", ex.Message); + } + + [Fact] + public async Task BindAsyncWithBodyArgument() + { + Todo originalTodo = new() + { + Name = "Write more tests!" + }; + + var httpContext = new DefaultHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); + var stream = new MemoryStream(requestBodyBytes); ; + httpContext.Request.Body = stream; + + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var mock = new Mock(); + mock.Setup(m => m.GetService(It.IsAny())).Returns(t => + { + if (t == typeof(IOptions)) + { + return Options.Create(jsonOptions); + } + return null; + }); + + httpContext.RequestServices = mock.Object; + httpContext.Request.Headers.Referer = "https://example.org"; + + var invoked = false; + + var requestDelegate = RequestDelegateFactory.Create((HttpContext context, MyBindAsyncRecord arg1, Todo todo) => + { + invoked = true; + context.Items[nameof(arg1)] = arg1; + context.Items[nameof(todo)] = todo; + }); + + await requestDelegate(httpContext); + + Assert.True(invoked); + var arg = httpContext.Items["arg1"] as MyBindAsyncRecord; + Assert.NotNull(arg); + Assert.Equal("https://example.org/", arg!.Uri.ToString()); + var todo = httpContext.Items["todo"] as Todo; + Assert.NotNull(todo); + Assert.Equal("Write more tests!", todo!.Name); + } + + [Fact] + public async Task BindAsyncRunsBeforeBodyBinding() + { + Todo originalTodo = new() + { + Name = "Write more tests!" + }; + + var httpContext = new DefaultHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); + var stream = new MemoryStream(requestBodyBytes); ; + httpContext.Request.Body = stream; + + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var mock = new Mock(); + mock.Setup(m => m.GetService(It.IsAny())).Returns(t => + { + if (t == typeof(IOptions)) + { + return Options.Create(jsonOptions); + } + return null; + }); + + httpContext.RequestServices = mock.Object; + httpContext.Request.Headers.Referer = "https://example.org"; + + var invoked = false; + + var requestDelegate = RequestDelegateFactory.Create((HttpContext context, CustomTodo customTodo, Todo todo) => + { + invoked = true; + context.Items[nameof(customTodo)] = customTodo; + context.Items[nameof(todo)] = todo; + }); + + await requestDelegate(httpContext); + + Assert.True(invoked); + var todo0 = httpContext.Items["customTodo"] as Todo; + Assert.NotNull(todo0); + Assert.Equal("Write more tests!", todo0!.Name); + var todo1 = httpContext.Items["todo"] as Todo; + Assert.NotNull(todo1); + Assert.Equal("Write more tests!", todo1!.Name); } [Fact] @@ -1825,11 +1952,9 @@ public async Task RequestDelegateHandlesBodyParamOptionality(Delegate @delegate, } } - public async Task RequestDelegateDoesNotSupportTryParseHttpContextOptionality() + [Fact] + public async Task RequestDelegateDoesSupportBindAsyncOptionality() { - // Not supplying any headers will cause the HttpContext TryParse overload to fail. - // However, RequestDelegateFactory cannot differentiate between a missing parameter and an invalid one, so - // the nullability of the argument doesn't change behavior. var httpContext = new DefaultHttpContext() { RequestServices = new ServiceCollection().AddSingleton(LoggerFactory).BuildServiceProvider(), @@ -1837,23 +1962,14 @@ public async Task RequestDelegateDoesNotSupportTryParseHttpContextOptionality() var invoked = false; - var requestDelegate = RequestDelegateFactory.Create((MyTryParseHttpContextRecord? arg1) => + var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncRecord? arg1) => { invoked = true; }); await requestDelegate(httpContext); - Assert.False(invoked); - Assert.False(httpContext.RequestAborted.IsCancellationRequested); - Assert.Equal(400, httpContext.Response.StatusCode); - - var logs = TestSink.Writes.ToArray(); - var log = Assert.Single(logs); - - Assert.Equal(new EventId(5, "ParamaterBindingFromHttpContextFailed"), log.EventId); - Assert.Equal(LogLevel.Debug, log.LogLevel); - Assert.Equal(@"Failed to bind parameter ""MyTryParseHttpContextRecord arg1"" from HttpContext.", log.Message); + Assert.True(invoked); } public static IEnumerable ServiceParamOptionalityData @@ -2030,6 +2146,16 @@ private class Todo : ITodo public bool IsComplete { get; set; } } + private class CustomTodo : Todo + { + public static async ValueTask BindAsync(HttpContext context) + { + var body = await context.Request.ReadFromJsonAsync(); + context.Request.Body.Position = 0; + return body; + } + } + private record struct TodoStruct(int Id, string? Name, bool IsComplete) : ITodo; private interface ITodo diff --git a/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs b/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs index 132dba5d2017..85b8792f7661 100644 --- a/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs +++ b/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs @@ -168,26 +168,25 @@ public void FindTryParseStringMethod_WorksForEnumsWhenNonGenericEnumParseIsUsed( } [Fact] - public void FindTryParseHttpContextMethod_FindsCorrectMethodOnClass() + public async Task FindBindAsyncMethod_FindsCorrectMethodOnClass() { - var type = typeof(TryParseHttpContextRecord); + var type = typeof(BindAsyncRecord); var cache = new TryParseMethodCache(); - var methodFound = cache.FindTryParseHttpContextMethod(type); + var methodFound = cache.FindBindAsyncMethod(type); Assert.NotNull(methodFound); var parsedValue = Expression.Variable(type, "parsedValue"); - var call = methodFound!(parsedValue) as MethodCallExpression; + var call = methodFound as MethodCallExpression; Assert.NotNull(call); var method = call!.Method; var parameters = method.GetParameters(); + Assert.Single(parameters); Assert.Equal(typeof(HttpContext), parameters[0].ParameterType); - Assert.True(parameters[1].IsOut); - var parseHttpContext = Expression.Lambda>(Expression.Block(new[] { parsedValue }, - call, - parsedValue), cache.HttpContextExpr).Compile(); + var parseHttpContext = Expression.Lambda>>(Expression.Block(new[] { parsedValue }, + call), cache.HttpContextExpr).Compile(); var httpContext = new DefaultHttpContext { @@ -200,10 +199,10 @@ public void FindTryParseHttpContextMethod_FindsCorrectMethodOnClass() }, }; - Assert.Equal(new TryParseHttpContextRecord(42), parseHttpContext(httpContext)); + Assert.Equal(new BindAsyncRecord(42), await parseHttpContext(httpContext)); } - public static IEnumerable TryParseHttpContextParameterInfoData + public static IEnumerable BindAsyncParameterInfoData { get { @@ -211,28 +210,28 @@ public static IEnumerable TryParseHttpContextParameterInfoData { new[] { - GetFirstParameter((TryParseHttpContextRecord arg) => TryParseHttpContextRecordMethod(arg)), + GetFirstParameter((BindAsyncRecord arg) => BindAsyncRecordMethod(arg)), }, new[] { - GetFirstParameter((TryParseHttpContextStruct arg) => TryParseHttpContextStructMethod(arg)), + GetFirstParameter((BindAsyncStruct arg) => BindAsyncStructMethod(arg)), }, }; } } [Theory] - [MemberData(nameof(TryParseHttpContextParameterInfoData))] - public void HasTryParseHttpContextMethod_ReturnsTrueWhenMethodExists(ParameterInfo parameterInfo) + [MemberData(nameof(BindAsyncParameterInfoData))] + public void HasBindAsyncMethod_ReturnsTrueWhenMethodExists(ParameterInfo parameterInfo) { - Assert.True(new TryParseMethodCache().HasTryParseHttpContextMethod(parameterInfo)); + Assert.True(new TryParseMethodCache().HasBindAsyncMethod(parameterInfo)); } [Fact] - public void FindTryParseHttpContextMethod_DoesNotFindMethodGivenNullableType() + public void FindBindAsyncMethod_DoesNotFindMethodGivenNullableType() { - var parameterInfo = GetFirstParameter((TryParseHttpContextStruct? arg) => TryParseHttpContextNullableStructMethod(arg)); - Assert.False(new TryParseMethodCache().HasTryParseHttpContextMethod(parameterInfo)); + var parameterInfo = GetFirstParameter((BindAsyncStruct? arg) => BindAsyncNullableStructMethod(arg)); + Assert.False(new TryParseMethodCache().HasBindAsyncMethod(parameterInfo)); } enum Choice @@ -246,9 +245,9 @@ private static void TryParseStringRecordMethod(TryParseStringRecord arg) { } private static void TryParseStringStructMethod(TryParseStringStruct arg) { } private static void TryParseStringNullableStructMethod(TryParseStringStruct? arg) { } - private static void TryParseHttpContextRecordMethod(TryParseHttpContextRecord arg) { } - private static void TryParseHttpContextStructMethod(TryParseHttpContextStruct arg) { } - private static void TryParseHttpContextNullableStructMethod(TryParseHttpContextStruct? arg) { } + private static void BindAsyncRecordMethod(BindAsyncRecord arg) { } + private static void BindAsyncStructMethod(BindAsyncStruct arg) { } + private static void BindAsyncNullableStructMethod(BindAsyncStruct? arg) { } private static ParameterInfo GetFirstParameter(Expression> expr) @@ -287,33 +286,29 @@ public static bool TryParse(string? value, IFormatProvider formatProvider, out T } } - private record TryParseHttpContextRecord(int Value) + private record BindAsyncRecord(int Value) { - public static bool TryParse(HttpContext context, out TryParseHttpContextRecord? result) + public static ValueTask BindAsync(HttpContext context) { if (!int.TryParse(context.Request.Headers.ETag, out var val)) { - result = null; - return false; + return ValueTask.FromResult(null); } - result = new TryParseHttpContextRecord(val); - return true; + return ValueTask.FromResult(new BindAsyncRecord(val)); } } - private record struct TryParseHttpContextStruct(int Value) + private record struct BindAsyncStruct(int Value) { - public static bool TryParse(HttpContext context, out TryParseHttpContextStruct result) + public static ValueTask BindAsync(HttpContext context) { if (!int.TryParse(context.Request.Headers.ETag, out var val)) { - result = default; - return false; + return ValueTask.FromResult(null); } - result = new TryParseHttpContextStruct(val); - return true; + return ValueTask.FromResult(new BindAsyncRecord(val)); } } } diff --git a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs index 06084059080c..bb9a36b868d0 100644 --- a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs +++ b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs @@ -184,7 +184,7 @@ private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string parameter.ParameterType == typeof(HttpResponse) || parameter.ParameterType == typeof(ClaimsPrincipal) || parameter.ParameterType == typeof(CancellationToken) || - TryParseMethodCache.HasTryParseHttpContextMethod(parameter) || + TryParseMethodCache.HasBindAsyncMethod(parameter) || _serviceProviderIsService?.IsService(parameter.ParameterType) == true) { return (BindingSource.Services, parameter.Name ?? string.Empty, false); diff --git a/src/Shared/TryParseMethodCache.cs b/src/Shared/TryParseMethodCache.cs index 28c63d85b62e..91aff2f4a6cc 100644 --- a/src/Shared/TryParseMethodCache.cs +++ b/src/Shared/TryParseMethodCache.cs @@ -20,7 +20,7 @@ internal sealed class TryParseMethodCache // Since this is shared source, the cache won't be shared between RequestDelegateFactory and the ApiDescriptionProvider sadly :( private readonly ConcurrentDictionary?> _stringMethodCallCache = new(); - private readonly ConcurrentDictionary?> _httpContextMethodCallCache = new(); + private readonly ConcurrentDictionary _bindAsyncMethodCallCache = new(); internal readonly ParameterExpression TempSourceStringExpr = Expression.Variable(typeof(string), "tempSourceString"); internal readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); @@ -43,8 +43,8 @@ public bool HasTryParseStringMethod(ParameterInfo parameter) return FindTryParseStringMethod(nonNullableParameterType) is not null; } - public bool HasTryParseHttpContextMethod(ParameterInfo parameter) => - FindTryParseHttpContextMethod(parameter.ParameterType) is not null; + public bool HasBindAsyncMethod(ParameterInfo parameter) => + FindBindAsyncMethod(parameter.ParameterType) is not null; public Func? FindTryParseStringMethod(Type type) { @@ -125,21 +125,23 @@ public bool HasTryParseHttpContextMethod(ParameterInfo parameter) => return _stringMethodCallCache.GetOrAdd(type, Finder); } - public Func? FindTryParseHttpContextMethod(Type type) + public Expression? FindBindAsyncMethod(Type type) { - Func? Finder(Type type) + Expression? Finder(Type type) { - var methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext), type.MakeByRefType() }); + var methodInfo = type.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext) }); - if (methodInfo is not null) + // We're looking for a method with the following signature: + // static ValueTask BindAsync(HttpContext context) + if (methodInfo is not null && methodInfo.ReturnType == typeof(ValueTask)) { - return (expression) => Expression.Call(methodInfo, HttpContextExpr, expression); + return Expression.Call(methodInfo, HttpContextExpr); } return null; } - return _httpContextMethodCallCache.GetOrAdd(type, Finder); + return _bindAsyncMethodCallCache.GetOrAdd(type, Finder); } private static MethodInfo GetEnumTryParseMethod(bool preferNonGenericEnumParseOverload) From 910a45ef1ff3e9ed5e032c2ffbb115f95d92910e Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 19 Aug 2021 03:34:03 -0700 Subject: [PATCH 2/4] Apply suggestions from code review --- src/Http/Http.Extensions/src/RequestDelegateFactory.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 9052c3a06d38..632bbc7aa812 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -274,7 +274,6 @@ private static Expression CreateArgument(int index, ParameterInfo parameter, Fac // when RDF.Create is manually invoked. if (factoryContext.RouteParameters is { } routeParams) { - if (routeParams.Contains(parameter.Name, StringComparer.OrdinalIgnoreCase)) { // We're in the fallback case and we have a parameter and route parameter match so don't fallback From bf8b0e603ae3e12fed25b52f55870531c6165207 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 19 Aug 2021 11:07:03 -0700 Subject: [PATCH 3/4] Fixed tests --- .../test/EndpointMetadataApiDescriptionProviderTest.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs index 83f1704db921..c6d59ce15024 100644 --- a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs +++ b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs @@ -299,7 +299,7 @@ public void DoesNotAddFromServiceParameterAsService() Assert.Empty(GetApiDescription((HttpResponse response) => { }).ParameterDescriptions); Assert.Empty(GetApiDescription((ClaimsPrincipal user) => { }).ParameterDescriptions); Assert.Empty(GetApiDescription((CancellationToken token) => { }).ParameterDescriptions); - Assert.Empty(GetApiDescription((TryParseHttpContextRecord context) => { }).ParameterDescriptions); + Assert.Empty(GetApiDescription((BindAsyncRecord context) => { }).ParameterDescriptions); } [Fact] @@ -681,11 +681,11 @@ public static bool TryParse(string value, out TryParseStringRecord result) => throw new NotImplementedException(); } - private record TryParseHttpContextRecord(int Value) + private record BindAsyncRecord(int Value) { - public static bool TryParse(HttpContext context, out TryParseHttpContextRecord result) => + public static ValueTask BindAsync(HttpContext context) => throw new NotImplementedException(); - public static bool TryParse(string value, out TryParseHttpContextRecord result) => + public static bool TryParse(string value, out BindAsyncRecord result) => throw new NotImplementedException(); } } From baf839591daafc415df4f5c057f6b62f1ae996c1 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 19 Aug 2021 11:58:43 -0700 Subject: [PATCH 4/4] Remove index from RequestDelegateFactory --- .../src/RequestDelegateFactory.cs | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 632bbc7aa812..94d04695dc99 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -187,11 +187,10 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory } var args = new Expression[parameters.Length]; - factoryContext.ParameterCount = parameters.Length; for (var i = 0; i < parameters.Length; i++) { - args[i] = CreateArgument(i, parameters[i], factoryContext); + args[i] = CreateArgument(parameters[i], factoryContext); } if (factoryContext.HasMultipleBodyParameters) @@ -203,7 +202,7 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory return args; } - private static Expression CreateArgument(int index, ParameterInfo parameter, FactoryContext factoryContext) + private static Expression CreateArgument(ParameterInfo parameter, FactoryContext factoryContext) { if (parameter.Name is null) { @@ -264,7 +263,7 @@ private static Expression CreateArgument(int index, ParameterInfo parameter, Fac } else if (TryParseMethodCache.HasBindAsyncMethod(parameter)) { - return BindParameterFromBindAsync(index, parameter, factoryContext); + return BindParameterFromBindAsync(parameter, factoryContext); } else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseStringMethod(parameter)) { @@ -513,15 +512,15 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, // Looping over arrays is faster var binders = factoryContext.ParameterBinders.ToArray(); - var count = factoryContext.ParameterCount; + var count = binders.Length; return async (target, httpContext) => { var boundValues = new object?[count]; - foreach (var (index, binder) in binders) + for (var i = 0; i < count; i++) { - boundValues[index] = await binder(httpContext); + boundValues[i] = await binders[i](httpContext); } await continuation(target, httpContext, boundValues); @@ -548,16 +547,16 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, // Looping over arrays is faster var binders = factoryContext.ParameterBinders.ToArray(); - var count = factoryContext.ParameterCount; + var count = binders.Length; return async (target, httpContext) => { // Run these first so that they can potentially read and rewind the body var boundValues = new object?[count]; - foreach (var (index, binder) in binders) + for (var i = 0; i < count; i++) { - boundValues[index] = await binder(httpContext); + boundValues[i] = await binders[i](httpContext); } var bodyValue = defaultBodyValue; @@ -814,7 +813,7 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext); } - private static Expression BindParameterFromBindAsync(int index, ParameterInfo parameter, FactoryContext factoryContext) + private static Expression BindParameterFromBindAsync(ParameterInfo parameter, FactoryContext factoryContext) { // We reference the boundValues array by parameter index here var nullability = NullabilityContext.Create(parameter); @@ -825,10 +824,10 @@ private static Expression BindParameterFromBindAsync(int index, ParameterInfo pa // Compile the delegate to the BindAsync method for this parameter index var bindAsyncDelegate = Expression.Lambda>>(body, HttpContextExpr).Compile(); - factoryContext.ParameterBinders.Add((index, bindAsyncDelegate)); + factoryContext.ParameterBinders.Add(bindAsyncDelegate); // boundValues[index] - var boundValueExpr = Expression.ArrayIndex(BoundValuesArrayExpr, Expression.Constant(index)); + var boundValueExpr = Expression.ArrayIndex(BoundValuesArrayExpr, Expression.Constant(factoryContext.ParameterBinders.Count - 1)); if (!isOptional) { @@ -1097,8 +1096,7 @@ private class FactoryContext public bool UsingTempSourceString { get; set; } public List ExtraLocals { get; } = new(); public List ParamCheckExpressions { get; } = new(); - public List<(int, Func>)> ParameterBinders { get; } = new(); - public int ParameterCount { get; set; } + public List>> ParameterBinders { get; } = new(); public Dictionary TrackedParameters { get; } = new(); public bool HasMultipleBodyParameters { get; set; }