Skip to content

[release/6.0-rc2] Add support for BindAsync without ParameterInfo #36590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -840,12 +840,12 @@ private static Expression BindParameterFromBindAsync(ParameterInfo parameter, Fa
var isOptional = IsOptionalParameter(parameter, factoryContext);

// Get the BindAsync method for the type.
var bindAsyncExpression = ParameterBindingMethodCache.FindBindAsyncMethod(parameter);
var bindAsyncMethod = ParameterBindingMethodCache.FindBindAsyncMethod(parameter);
// We know BindAsync exists because there's no way to opt-in without defining the method on the type.
Debug.Assert(bindAsyncExpression is not null);
Debug.Assert(bindAsyncMethod.Expression is not null);

// Compile the delegate to the BindAsync method for this parameter index
var bindAsyncDelegate = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(bindAsyncExpression, HttpContextExpr).Compile();
var bindAsyncDelegate = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(bindAsyncMethod.Expression, HttpContextExpr).Compile();
factoryContext.ParameterBinders.Add(bindAsyncDelegate);

// boundValues[index]
Expand All @@ -854,6 +854,7 @@ private static Expression BindParameterFromBindAsync(ParameterInfo parameter, Fa
if (!isOptional)
{
var typeName = TypeNameHelper.GetTypeDisplayName(parameter.ParameterType, fullName: false);
var message = bindAsyncMethod.ParamCount == 2 ? $"{typeName}.BindAsync(HttpContext, ParameterInfo)" : $"{typeName}.BindAsync(HttpContext)";
var checkRequiredBodyBlock = Expression.Block(
Expression.IfThen(
Expression.Equal(boundValueExpr, Expression.Constant(null)),
Expand All @@ -863,7 +864,7 @@ private static Expression BindParameterFromBindAsync(ParameterInfo parameter, Fa
HttpContextExpr,
Expression.Constant(typeName),
Expression.Constant(parameter.Name),
Expression.Constant($"{typeName}.BindAsync(HttpContext, ParameterInfo)"),
Expression.Constant(message),
Expression.Constant(factoryContext.ThrowOnBadRequest))
)
)
Expand Down
81 changes: 79 additions & 2 deletions src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,13 @@ public async Task FindBindAsyncMethod_FindsCorrectMethodOnClass()
var parameter = new MockParameterInfo(type, "bindAsyncRecord");
var methodFound = cache.FindBindAsyncMethod(parameter);

Assert.NotNull(methodFound);
Assert.NotNull(methodFound.Expression);
Assert.Equal(2, methodFound.ParamCount);

var parsedValue = Expression.Variable(type, "parsedValue");

var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object>>>(
Expression.Block(new[] { parsedValue }, methodFound!),
Expression.Block(new[] { parsedValue }, methodFound.Expression!),
ParameterBindingMethodCache.HttpContextExpr).Compile();

var httpContext = new DefaultHttpContext
Expand All @@ -195,6 +196,37 @@ public async Task FindBindAsyncMethod_FindsCorrectMethodOnClass()
Assert.Equal(new BindAsyncRecord(42), await parseHttpContext(httpContext));
}

[Fact]
public async Task FindBindAsyncMethod_FindsSingleArgBindAsync()
{
var type = typeof(BindAsyncSingleArgStruct);
var cache = new ParameterBindingMethodCache();
var parameter = new MockParameterInfo(type, "bindAsyncSingleArgStruct");
var methodFound = cache.FindBindAsyncMethod(parameter);

Assert.NotNull(methodFound.Expression);
Assert.Equal(1, methodFound.ParamCount);

var parsedValue = Expression.Variable(type, "parsedValue");

var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object>>>(
Expression.Block(new[] { parsedValue }, methodFound.Expression!),
ParameterBindingMethodCache.HttpContextExpr).Compile();

var httpContext = new DefaultHttpContext
{
Request =
{
Headers =
{
["ETag"] = "42",
},
},
};

Assert.Equal(new BindAsyncSingleArgStruct(42), await parseHttpContext(httpContext));
}

public static IEnumerable<object[]> BindAsyncParameterInfoData
{
get
Expand All @@ -209,6 +241,14 @@ public static IEnumerable<object[]> BindAsyncParameterInfoData
{
GetFirstParameter((BindAsyncStruct arg) => BindAsyncStructMethod(arg)),
},
new[]
{
GetFirstParameter((BindAsyncSingleArgRecord arg) => BindAsyncSingleArgRecordMethod(arg)),
},
new[]
{
GetFirstParameter((BindAsyncSingleArgStruct arg) => BindAsyncSingleArgStructMethod(arg)),
}
};
}
}
Expand Down Expand Up @@ -250,6 +290,11 @@ private static void BindAsyncStructMethod(BindAsyncStruct arg) { }
private static void BindAsyncNullableStructMethod(BindAsyncStruct? arg) { }
private static void NullableReturningBindAsyncStructMethod(NullableReturningBindAsyncStruct arg) { }

private static void BindAsyncSingleArgRecordMethod(BindAsyncSingleArgRecord arg) { }
private static void BindAsyncSingleArgStructMethod(BindAsyncSingleArgStruct arg) { }
private static void BindAsyncNullableSingleArgStructMethod(BindAsyncSingleArgStruct? arg) { }
private static void NullableReturningBindAsyncSingleArgStructMethod(NullableReturningBindAsyncSingleArgStruct arg) { }

private static ParameterInfo GetFirstParameter<T>(Expression<Action<T>> expr)
{
var mc = (MethodCallExpression)expr.Body;
Expand Down Expand Up @@ -324,6 +369,38 @@ private record struct NullableReturningBindAsyncStruct(int Value)
throw new NotImplementedException();
}

private record BindAsyncSingleArgRecord(int Value)
{
public static ValueTask<BindAsyncSingleArgRecord?> BindAsync(HttpContext context)
{
if (!int.TryParse(context.Request.Headers.ETag, out var val))
{
return new(result: null);
}

return new(result: new(val));
}
}

private record struct BindAsyncSingleArgStruct(int Value)
{
public static ValueTask<BindAsyncSingleArgStruct> BindAsync(HttpContext context)
{
if (!int.TryParse(context.Request.Headers.ETag, out var val))
{
throw new BadHttpRequestException("The request is missing the required ETag header.");
}

return new(result: new(val));
}
}

private record struct NullableReturningBindAsyncSingleArgStruct(int Value)
{
public static ValueTask<NullableReturningBindAsyncStruct?> BindAsync(HttpContext context, ParameterInfo parameter) =>
throw new NotImplementedException();
}

private class MockParameterInfo : ParameterInfo
{
public MockParameterInfo(Type type, string name)
Expand Down
144 changes: 141 additions & 3 deletions src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,54 @@ public static async ValueTask<MyAwaitedBindAsyncStruct> BindAsync(HttpContext co
}
}

private record struct MyBothBindAsyncStruct(Uri Uri)
{
public static ValueTask<MyBothBindAsyncStruct> BindAsync(HttpContext context, ParameterInfo parameter)
{
Assert.True(parameter.ParameterType == typeof(MyBothBindAsyncStruct) || parameter.ParameterType == typeof(MyBothBindAsyncStruct?));
Assert.Equal("myBothBindAsyncStruct", parameter.Name);

if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri))
{
throw new BadHttpRequestException("The request is missing the required Referer header.");
}

return new(result: new(uri));
}

// BindAsync with ParameterInfo is preferred
public static ValueTask<MyBothBindAsyncStruct> BindAsync(HttpContext context)
{
throw new NotImplementedException();
}
}

private record struct MySimpleBindAsyncStruct(Uri Uri)
{
public static ValueTask<MySimpleBindAsyncStruct> BindAsync(HttpContext context)
{
if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri))
{
throw new BadHttpRequestException("The request is missing the required Referer header.");
}

return new(result: new(uri));
}
}

private record MySimpleBindAsyncRecord(Uri Uri)
{
public static ValueTask<MySimpleBindAsyncRecord?> BindAsync(HttpContext context)
{
if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri))
{
return new(result: null);
}

return new(result: new(uri));
}
}

[Theory]
[MemberData(nameof(TryParsableParameters))]
public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromRouteValue(Delegate action, string? routeValue, object? expectedParameterValue)
Expand Down Expand Up @@ -724,6 +772,24 @@ public async Task RequestDelegateUsesBindAsyncOverTryParseGivenNullableStruct()
Assert.Equal(new MyBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["myBindAsyncStruct"]);
}

[Fact]
public async Task RequestDelegateUsesParameterInfoBindAsyncOverOtherBindAsync()
{
var httpContext = CreateHttpContext();

httpContext.Request.Headers.Referer = "https://example.org";

var resultFactory = RequestDelegateFactory.Create((HttpContext httpContext, MyBothBindAsyncStruct? myBothBindAsyncStruct) =>
{
httpContext.Items["myBothBindAsyncStruct"] = myBothBindAsyncStruct;
});

var requestDelegate = resultFactory.RequestDelegate;
await requestDelegate(httpContext);

Assert.Equal(new MyBothBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["myBothBindAsyncStruct"]);
}

[Fact]
public async Task RequestDelegateUsesTryParseOverBindAsyncGivenExplicitAttribute()
{
Expand Down Expand Up @@ -873,7 +939,7 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2)
[Fact]
public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response()
{
// Not supplying any headers will cause the HttpContext TryParse overload to fail.
// Not supplying any headers will cause the HttpContext BindAsync overload to return null.
var httpContext = CreateHttpContext();
var invoked = false;

Expand Down Expand Up @@ -905,7 +971,7 @@ public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response()
[Fact]
public async Task RequestDelegateLogsBindAsyncFailuresAndThrowsIfThrowOnBadRequest()
{
// Not supplying any headers will cause the HttpContext TryParse overload to fail.
// Not supplying any headers will cause the HttpContext BindAsync overload to return null.
var httpContext = CreateHttpContext();
var invoked = false;

Expand All @@ -931,10 +997,72 @@ public async Task RequestDelegateLogsBindAsyncFailuresAndThrowsIfThrowOnBadReque
Assert.Equal(400, badHttpRequestException.StatusCode);
}

[Fact]
public async Task RequestDelegateLogsSingleArgBindAsyncFailuresAndSets400Response()
{
// Not supplying any headers will cause the HttpContext BindAsync overload to return null.
var httpContext = CreateHttpContext();
var invoked = false;

var factoryResult = RequestDelegateFactory.Create((MySimpleBindAsyncRecord mySimpleBindAsyncRecord1,
MySimpleBindAsyncRecord mySimpleBindAsyncRecord2) =>
{
invoked = true;
});

var requestDelegate = factoryResult.RequestDelegate;
await requestDelegate(httpContext);

Assert.False(invoked);
Assert.False(httpContext.RequestAborted.IsCancellationRequested);
Assert.Equal(400, httpContext.Response.StatusCode);

var logs = TestSink.Writes.ToArray();

Assert.Equal(2, logs.Length);

Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[0].EventId);
Assert.Equal(LogLevel.Debug, logs[0].LogLevel);
Assert.Equal(@"Required parameter ""MySimpleBindAsyncRecord mySimpleBindAsyncRecord1"" was not provided from MySimpleBindAsyncRecord.BindAsync(HttpContext).", logs[0].Message);

Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[1].EventId);
Assert.Equal(LogLevel.Debug, logs[1].LogLevel);
Assert.Equal(@"Required parameter ""MySimpleBindAsyncRecord mySimpleBindAsyncRecord2"" was not provided from MySimpleBindAsyncRecord.BindAsync(HttpContext).", logs[1].Message);
}

[Fact]
public async Task RequestDelegateLogsSingleArgBindAsyncFailuresAndThrowsIfThrowOnBadRequest()
{
// Not supplying any headers will cause the HttpContext BindAsync overload to return null.
var httpContext = CreateHttpContext();
var invoked = false;

var factoryResult = RequestDelegateFactory.Create((MySimpleBindAsyncRecord mySimpleBindAsyncRecord1,
MySimpleBindAsyncRecord mySimpleBindAsyncRecord2) =>
{
invoked = true;
}, new() { ThrowOnBadRequest = true });

var requestDelegate = factoryResult.RequestDelegate;
var badHttpRequestException = await Assert.ThrowsAsync<BadHttpRequestException>(() => requestDelegate(httpContext));

Assert.False(invoked);

// The httpContext should be untouched.
Assert.False(httpContext.RequestAborted.IsCancellationRequested);
Assert.Equal(200, httpContext.Response.StatusCode);
Assert.False(httpContext.Response.HasStarted);

// We don't log bad requests when we throw.
Assert.Empty(TestSink.Writes);

Assert.Equal(@"Required parameter ""MySimpleBindAsyncRecord mySimpleBindAsyncRecord1"" was not provided from MySimpleBindAsyncRecord.BindAsync(HttpContext).", badHttpRequestException.Message);
Assert.Equal(400, badHttpRequestException.StatusCode);
}

[Fact]
public async Task BindAsyncExceptionsAreUncaught()
{
// Not supplying any headers will cause the HttpContext BindAsync overload to fail.
var httpContext = CreateHttpContext();

var factoryResult = RequestDelegateFactory.Create((MyBindAsyncTypeThatThrows arg1) => { });
Expand Down Expand Up @@ -2239,6 +2367,10 @@ void nullableReferenceType(HttpContext context, MyBindAsyncRecord? myBindAsyncRe
{
context.Items["uri"] = myBindAsyncRecord?.Uri;
}
void requiredReferenceTypeSimple(HttpContext context, MySimpleBindAsyncRecord mySimpleBindAsyncRecord)
{
context.Items["uri"] = mySimpleBindAsyncRecord.Uri;
}


void requiredValueType(HttpContext context, MyNullableBindAsyncStruct myNullableBindAsyncStruct)
Expand All @@ -2253,11 +2385,16 @@ void nullableValueType(HttpContext context, MyNullableBindAsyncStruct? myNullabl
{
context.Items["uri"] = myNullableBindAsyncStruct?.Uri;
}
void requiredValueTypeSimple(HttpContext context, MySimpleBindAsyncStruct mySimpleBindAsyncStruct)
{
context.Items["uri"] = mySimpleBindAsyncStruct.Uri;
}

return new object?[][]
{
new object?[] { (Action<HttpContext, MyBindAsyncRecord>)requiredReferenceType, false, true, false },
new object?[] { (Action<HttpContext, MyBindAsyncRecord>)requiredReferenceType, true, false, false, },
new object?[] { (Action<HttpContext, MySimpleBindAsyncRecord>)requiredReferenceTypeSimple, true, false, false },

new object?[] { (Action<HttpContext, MyBindAsyncRecord?>)defaultReferenceType, false, false, false, },
new object?[] { (Action<HttpContext, MyBindAsyncRecord?>)defaultReferenceType, true, false, false },
Expand All @@ -2267,6 +2404,7 @@ void nullableValueType(HttpContext context, MyNullableBindAsyncStruct? myNullabl

new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct>)requiredValueType, false, true, true },
new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct>)requiredValueType, true, false, true },
new object?[] { (Action<HttpContext, MySimpleBindAsyncStruct>)requiredValueTypeSimple, true, false, true },

new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct?>)defaultValueType, false, false, true },
new object?[] { (Action<HttpContext, MyNullableBindAsyncStruct?>)defaultValueType, true, false, true },
Expand Down
Loading