Skip to content

Added support for type based parameter binding #35496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 122 additions & 57 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ public static partial class RequestDelegateFactory
Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue));
private static readonly MethodInfo LogRequiredParameterNotProvidedMethod = GetMethodInfo<Action<HttpContext, string, string>>((httpContext, parameterType, parameterName) =>
Log.RequiredParameterNotProvided(httpContext, parameterType, parameterName));
private static readonly MethodInfo LogParameterBindingFromHttpContextFailedMethod = GetMethodInfo<Action<HttpContext, string, string>>((httpContext, parameterType, parameterName) =>
Log.ParameterBindingFromHttpContextFailed(httpContext, parameterType, parameterName));

private static readonly ParameterExpression TargetExpr = Expression.Parameter(typeof(object), "target");
private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue");
private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure");
private static readonly ParameterExpression BoundValuesArrayExpr = Expression.Parameter(typeof(object[]), "boundValues");

private static ParameterExpression HttpContextExpr => TryParseMethodCache.HttpContextExpr;
private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.RequestServices))!);
Expand Down Expand Up @@ -198,7 +197,6 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory
{
var errorMessage = BuildErrorMessageForMultipleBodyParameters(factoryContext);
throw new InvalidOperationException(errorMessage);

}

return args;
Expand Down Expand Up @@ -263,9 +261,9 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
{
return RequestAbortedExpr;
}
else if (TryParseMethodCache.HasTryParseHttpContextMethod(parameter))
else if (TryParseMethodCache.HasBindAsyncMethod(parameter))
{
return BindParameterFromTryParseHttpContext(parameter, factoryContext);
return BindParameterFromBindAsync(parameter, factoryContext);
}
else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseStringMethod(parameter))
{
Expand All @@ -275,7 +273,6 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext
// when RDF.Create is manually invoked.
if (factoryContext.RouteParameters is { } routeParams)
{

if (routeParams.Contains(parameter.Name, StringComparer.OrdinalIgnoreCase))
{
// We're in the fallback case and we have a parameter and route parameter match so don't fallback
Expand Down Expand Up @@ -361,7 +358,6 @@ private static Expression CreateParamCheckingResponseWritingMethodCall(
var localVariables = new ParameterExpression[factoryContext.ExtraLocals.Count + 1];
var checkParamAndCallMethod = new Expression[factoryContext.ParamCheckExpressions.Count + 1];


for (var i = 0; i < factoryContext.ExtraLocals.Count; i++)
{
localVariables[i] = factoryContext.ExtraLocals[i];
Expand Down Expand Up @@ -508,14 +504,33 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
{
if (factoryContext.JsonRequestBodyType is null)
{
if (factoryContext.ParameterBinders.Count > 0)
{
// We need to generate the code for reading from the custom binders calling into the delegate
var continuation = Expression.Lambda<Func<object?, HttpContext, object?[], Task>>(
responseWritingMethodCall, TargetExpr, HttpContextExpr, BoundValuesArrayExpr).Compile();

// Looping over arrays is faster
var binders = factoryContext.ParameterBinders.ToArray();
var count = binders.Length;

return async (target, httpContext) =>
{
var boundValues = new object?[count];

for (var i = 0; i < count; i++)
{
boundValues[i] = await binders[i](httpContext);
}

await continuation(target, httpContext, boundValues);
};
}

return Expression.Lambda<Func<object?, HttpContext, Task>>(
responseWritingMethodCall, TargetExpr, HttpContextExpr).Compile();
}

// We need to generate the code for reading from the body before calling into the delegate
var invoker = Expression.Lambda<Func<object?, HttpContext, object?, Task>>(
responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile();

var bodyType = factoryContext.JsonRequestBodyType;
object? defaultBodyValue = null;

Expand All @@ -524,31 +539,82 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall,
defaultBodyValue = Activator.CreateInstance(bodyType);
}

return async (target, httpContext) =>
if (factoryContext.ParameterBinders.Count > 0)
{
object? bodyValue = defaultBodyValue;
var feature = httpContext.Features.Get<IHttpRequestBodyDetectionFeature>();
if (feature?.CanHaveBody == true)
// We need to generate the code for reading from the body before calling into the delegate
var continuation = Expression.Lambda<Func<object?, HttpContext, object?, object?[], Task>>(
responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr, BoundValuesArrayExpr).Compile();

// Looping over arrays is faster
var binders = factoryContext.ParameterBinders.ToArray();
var count = binders.Length;

return async (target, httpContext) =>
{
try
// Run these first so that they can potentially read and rewind the body
var boundValues = new object?[count];

for (var i = 0; i < count; i++)
{
bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType);
boundValues[i] = await binders[i](httpContext);
}
catch (IOException ex)

var bodyValue = defaultBodyValue;
var feature = httpContext.Features.Get<IHttpRequestBodyDetectionFeature>();
if (feature?.CanHaveBody == true)
{
Log.RequestBodyIOException(httpContext, ex);
return;
try
{
bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType);
}
catch (IOException ex)
{
Log.RequestBodyIOException(httpContext, ex);
return;
}
catch (InvalidDataException ex)
{
Log.RequestBodyInvalidDataException(httpContext, ex);
httpContext.Response.StatusCode = 400;
return;
}
}
catch (InvalidDataException ex)
{

Log.RequestBodyInvalidDataException(httpContext, ex);
httpContext.Response.StatusCode = 400;
return;
await continuation(target, httpContext, bodyValue, boundValues);
};
}
else
{
// We need to generate the code for reading from the body before calling into the delegate
var continuation = Expression.Lambda<Func<object?, HttpContext, object?, Task>>(
responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile();

return async (target, httpContext) =>
{
var bodyValue = defaultBodyValue;
var feature = httpContext.Features.Get<IHttpRequestBodyDetectionFeature>();
if (feature?.CanHaveBody == true)
{
try
{
bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType);
}
catch (IOException ex)
{
Log.RequestBodyIOException(httpContext, ex);
return;
}
catch (InvalidDataException ex)
{

Log.RequestBodyInvalidDataException(httpContext, ex);
httpContext.Response.StatusCode = 400;
return;
}
}
}
await invoker(target, httpContext, bodyValue);
};
await continuation(target, httpContext, bodyValue);
};
}
}

private static Expression GetValueFromProperty(Expression sourceExpression, string key)
Expand Down Expand Up @@ -747,40 +813,40 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo
return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext);
}

private static Expression BindParameterFromTryParseHttpContext(ParameterInfo parameter, FactoryContext factoryContext)
private static Expression BindParameterFromBindAsync(ParameterInfo parameter, FactoryContext factoryContext)
{
// bool wasParamCheckFailure = false;
//
// // Assume "Foo param1" is the first parameter and "public static bool TryParse(HttpContext context, out Foo foo)" exists.
// Foo param1_local;
//
// if (!Foo.TryParse(httpContext, out param1_local))
// {
// wasParamCheckFailure = true;
// Log.ParameterBindingFromHttpContextFailed(httpContext, "Foo", "foo")
// }

var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local");
var tryParseMethodCall = TryParseMethodCache.FindTryParseHttpContextMethod(parameter.ParameterType);
// We reference the boundValues array by parameter index here
var nullability = NullabilityContext.Create(parameter);
var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable;

// There's no way to opt-in to using a TryParse method on HttpContext other than defining the method, so it's guaranteed to exist here.
Debug.Assert(tryParseMethodCall is not null);
// Get the BindAsync method
var body = TryParseMethodCache.FindBindAsyncMethod(parameter.ParameterType)!;

var parameterTypeNameConstant = Expression.Constant(parameter.ParameterType.Name);
var parameterNameConstant = Expression.Constant(parameter.Name);
// Compile the delegate to the BindAsync method for this parameter index
var bindAsyncDelegate = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(body, HttpContextExpr).Compile();
factoryContext.ParameterBinders.Add(bindAsyncDelegate);

var failBlock = Expression.Block(
Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)),
Expression.Call(LogParameterBindingFromHttpContextFailedMethod,
HttpContextExpr, parameterTypeNameConstant, parameterNameConstant));
// boundValues[index]
var boundValueExpr = Expression.ArrayIndex(BoundValuesArrayExpr, Expression.Constant(factoryContext.ParameterBinders.Count - 1));

var tryParseCall = tryParseMethodCall(argument);
var fullParamCheckBlock = Expression.IfThen(Expression.Not(tryParseCall), failBlock);
if (!isOptional)
{
var checkRequiredBodyBlock = Expression.Block(
Expression.IfThen(
Expression.Equal(boundValueExpr, Expression.Constant(null)),
Expression.Block(
Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)),
Expression.Call(LogRequiredParameterNotProvidedMethod,
HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name))
)
)
);

factoryContext.ExtraLocals.Add(argument);
factoryContext.ParamCheckExpressions.Add(fullParamCheckBlock);
factoryContext.ParamCheckExpressions.Add(checkRequiredBodyBlock);
}

return argument;
// (ParamterType)boundValues[i]
return Expression.Convert(boundValueExpr, parameter.ParameterType);
}

private static Expression BindParameterFromBody(ParameterInfo parameter, bool allowEmpty, FactoryContext factoryContext)
Expand All @@ -793,7 +859,6 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al
{
factoryContext.TrackedParameters.Remove(parameterName);
factoryContext.TrackedParameters.Add(parameterName, "UNKNOWN");

}
}

Expand Down Expand Up @@ -925,7 +990,6 @@ static async Task ExecuteAwaited(Task<T> task, HttpContext httpContext)

private static Task ExecuteTaskOfString(Task<string?> task, HttpContext httpContext)
{

SetPlaintextContentType(httpContext);
EnsureRequestTaskNotNull(task);

Expand Down Expand Up @@ -1032,6 +1096,7 @@ private class FactoryContext
public bool UsingTempSourceString { get; set; }
public List<ParameterExpression> ExtraLocals { get; } = new();
public List<Expression> ParamCheckExpressions { get; } = new();
public List<Func<HttpContext, ValueTask<object?>>> ParameterBinders { get; } = new();

public Dictionary<string, string> TrackedParameters { get; } = new();
public bool HasMultipleBodyParameters { get; set; }
Expand Down
Loading