Skip to content

Support resolving keyed services from DI in RDF and RDG #50095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/Http/Http.Extensions/gen/DiagnosticDescriptors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,12 @@ internal static class DiagnosticDescriptors
"Usage",
DiagnosticSeverity.Warning,
isEnabledByDefault: true);

public static DiagnosticDescriptor KeyedAndNotKeyedServiceAttributesNotSupported { get; } = new(
"RDG013",
new LocalizableResourceString(nameof(Resources.KeyedAndNotKeyedServiceAttributesNotSupported_Title), Resources.ResourceManager, typeof(Resources)),
new LocalizableResourceString(nameof(Resources.KeyedAndNotKeyedServiceAttributesNotSupported_Message), Resources.ResourceManager, typeof(Resources)),
"Usage",
DiagnosticSeverity.Warning,
isEnabledByDefault: true);
}
6 changes: 6 additions & 0 deletions src/Http/Http.Extensions/gen/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,10 @@
<data name="InaccessibleTypesNotSupported_Message" xml:space="preserve">
<value>Encountered inaccessible type '{0}' while processing endpoint. Compile-time endpoint generation will skip this endpoint.</value>
</data>
<data name="KeyedAndNotKeyedServiceAttributesNotSupported_Title" xml:space="preserve">
<value>Invalid source attributes</value>
</data>
<data name="KeyedAndNotKeyedServiceAttributesNotSupported_Message" xml:space="preserve">
<value>The [FromKeyedServices] attribute is not supported on parameters that are also annotated with IFromServiceMetadata.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ internal static string EmitParameterPreparation(this IEnumerable<EndpointParamet
case EndpointParameterSource.Service:
parameter.EmitServiceParameterPreparation(parameterPreparationBuilder);
break;
case EndpointParameterSource.KeyedService:
parameter.EmitKeyedServiceParameterPreparation(parameterPreparationBuilder);
break;
case EndpointParameterSource.AsParameters:
parameter.EmitAsParametersParameterPreparation(parameterPreparationBuilder, emitterContext);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,16 @@ internal static void EmitServiceParameterPreparation(this EndpointParameter endp
codeWriter.WriteLine($"var {endpointParameter.EmitHandlerArgument()} = {assigningCode};");
}

internal static void EmitKeyedServiceParameterPreparation(this EndpointParameter endpointParameter, CodeWriter codeWriter)
{
codeWriter.WriteLine(endpointParameter.EmitParameterDiagnosticComment());

var assigningCode = endpointParameter.IsOptional ?
$"httpContext.RequestServices.GetKeyedService<{endpointParameter.Type}>({endpointParameter.KeyedServiceKey});" :
$"httpContext.RequestServices.GetRequiredKeyedService<{endpointParameter.Type}>({endpointParameter.KeyedServiceKey})";
codeWriter.WriteLine($"var {endpointParameter.EmitHandlerArgument()} = {assigningCode};");
}

internal static void EmitAsParametersParameterPreparation(this EndpointParameter endpointParameter, CodeWriter codeWriter, EmitterContext emitterContext)
{
codeWriter.WriteLine(endpointParameter.EmitParameterDiagnosticComment());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,17 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I
else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata)))
{
Source = EndpointParameterSource.Service;
if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
{
var location = endpoint.Operation.Syntax.GetLocation();
endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, location));
}
}
else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
{
Source = EndpointParameterSource.KeyedService;
var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault();
KeyedServiceKey = SymbolDisplay.FormatPrimitive(constructorArgument.Value!, true, true);
}
else if (attributes.HasAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_AsParametersAttribute)))
{
Expand Down Expand Up @@ -260,6 +271,7 @@ private static bool ImplementsIEndpointParameterMetadataProvider(ITypeSymbol typ
public string? PropertyAsParameterInfoConstruction { get; set; }
public IEnumerable<EndpointParameter>? EndpointParameters { get; set; }
public bool IsFormFile { get; set; }
public string? KeyedServiceKey { get; set; }

// Only used for SpecialType parameters that need
// to be resolved by a specific WellKnownType
Expand Down Expand Up @@ -613,15 +625,17 @@ obj is EndpointParameter other &&
other.SymbolName == SymbolName &&
other.Ordinal == Ordinal &&
other.IsOptional == IsOptional &&
SymbolEqualityComparer.IncludeNullability.Equals(other.Type, Type);
SymbolEqualityComparer.IncludeNullability.Equals(other.Type, Type) &&
other.KeyedServiceKey == KeyedServiceKey;

public bool SignatureEquals(object obj) =>
obj is EndpointParameter other &&
SymbolEqualityComparer.IncludeNullability.Equals(other.Type, Type) &&
// The name of the parameter matters when we are querying for a specific parameter using
// an indexer, like `context.Request.RouteValues["id"]` or `context.Request.Query["id"]`
// and when generating log messages for required bodies or services.
other.SymbolName == SymbolName;
other.SymbolName == SymbolName &&
other.KeyedServiceKey == KeyedServiceKey;

public override int GetHashCode()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ internal enum EndpointParameterSource
JsonBodyOrService,
FormBody,
Service,
KeyedService,
// SpecialType refers to HttpContext, HttpRequest, CancellationToken, Stream, etc...
// that are specially checked for in RequestDelegateFactory.CreateArgument()
SpecialType,
Expand Down
27 changes: 27 additions & 0 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ public static partial class RequestDelegateFactory
private static readonly MethodInfo ExecuteAwaitedReturnMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteAwaitedReturn), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!;
private static readonly MethodInfo GetServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!;
private static readonly MethodInfo GetRequiredKeyedServiceMethod = typeof(ServiceProviderKeyedServiceExtensions).GetMethod(nameof(ServiceProviderKeyedServiceExtensions.GetRequiredKeyedService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider), typeof(object) })!;
private static readonly MethodInfo GetKeyedServiceMethod = typeof(ServiceProviderKeyedServiceExtensions).GetMethod(nameof(ServiceProviderKeyedServiceExtensions.GetKeyedService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider), typeof(object) })!;
private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!;
Expand Down Expand Up @@ -761,9 +763,19 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat
}
else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)))
{
if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is not null)
{
throw new NotSupportedException(
$"The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}.");
}
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribute);
return BindParameterFromService(parameter, factoryContext);
}
else if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is { } keyedServicesAttribute)
{
var key = keyedServicesAttribute.Key;
return BindParameterFromKeyedService(parameter, key, factoryContext);
}
else if (parameterCustomAttributes.OfType<AsParametersAttribute>().Any())
{
if (parameter is PropertyAsParameterInfo)
Expand Down Expand Up @@ -1563,6 +1575,21 @@ private static Expression BindParameterFromService(ParameterInfo parameter, Requ
return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr);
}

private static Expression BindParameterFromKeyedService(ParameterInfo parameter, object key, RequestDelegateFactoryContext factoryContext)
{
var isOptional = IsOptionalParameter(parameter, factoryContext);

if (isOptional)
{
return Expression.Call(GetKeyedServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr, Expression.Convert(
Expression.Constant(key),
typeof(object)));
}
return Expression.Call(GetRequiredKeyedServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr, Expression.Convert(
Expression.Constant(key),
typeof(object)));
}

private static Expression BindParameterFromValue(ParameterInfo parameter, Expression valueExpression, RequestDelegateFactoryContext factoryContext, string source)
{
if (parameter.ParameterType == typeof(string) || parameter.ParameterType == typeof(string[])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 5)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -166,8 +166,8 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 28, 5)]
[InterceptsLocation(@"TestMapActions.cs", 29, 5)]
[InterceptsLocation(@"TestMapActions.cs", 30, 5)]
internal static RouteHandlerBuilder MapGet1(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Generated
private static readonly JsonOptions FallbackJsonOptions = new();
private static readonly string[] GetVerb = new[] { global::Microsoft.AspNetCore.Http.HttpMethods.Get };

[InterceptsLocation(@"TestMapActions.cs", 25, 13)]
[InterceptsLocation(@"TestMapActions.cs", 26, 13)]
internal static RouteHandlerBuilder MapGet0(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 26, 5)]
[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
internal static RouteHandlerBuilder MapGet1(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -262,7 +262,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 27, 5)]
[InterceptsLocation(@"TestMapActions.cs", 28, 5)]
internal static RouteHandlerBuilder MapGet2(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -359,7 +359,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 28, 5)]
[InterceptsLocation(@"TestMapActions.cs", 29, 5)]
internal static RouteHandlerBuilder MapGet3(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -454,7 +454,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 29, 5)]
[InterceptsLocation(@"TestMapActions.cs", 30, 5)]
internal static RouteHandlerBuilder MapGet4(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -563,7 +563,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 30, 5)]
[InterceptsLocation(@"TestMapActions.cs", 31, 5)]
internal static RouteHandlerBuilder MapGet5(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -658,7 +658,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 31, 5)]
[InterceptsLocation(@"TestMapActions.cs", 32, 5)]
internal static RouteHandlerBuilder MapGet6(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -755,7 +755,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 32, 5)]
[InterceptsLocation(@"TestMapActions.cs", 33, 5)]
internal static RouteHandlerBuilder MapGet7(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -850,7 +850,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 33, 5)]
[InterceptsLocation(@"TestMapActions.cs", 34, 5)]
internal static RouteHandlerBuilder MapGet8(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -958,7 +958,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 34, 5)]
[InterceptsLocation(@"TestMapActions.cs", 35, 5)]
internal static RouteHandlerBuilder MapGet9(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1052,7 +1052,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 35, 5)]
[InterceptsLocation(@"TestMapActions.cs", 36, 5)]
internal static RouteHandlerBuilder MapGet10(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1148,7 +1148,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 36, 5)]
[InterceptsLocation(@"TestMapActions.cs", 37, 5)]
internal static RouteHandlerBuilder MapGet11(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1242,7 +1242,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 37, 5)]
[InterceptsLocation(@"TestMapActions.cs", 38, 5)]
internal static RouteHandlerBuilder MapGet12(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1351,7 +1351,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 38, 5)]
[InterceptsLocation(@"TestMapActions.cs", 39, 5)]
internal static RouteHandlerBuilder MapGet13(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1446,7 +1446,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 39, 5)]
[InterceptsLocation(@"TestMapActions.cs", 40, 5)]
internal static RouteHandlerBuilder MapGet14(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1554,7 +1554,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 40, 5)]
[InterceptsLocation(@"TestMapActions.cs", 41, 5)]
internal static RouteHandlerBuilder MapGet15(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1648,7 +1648,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 41, 5)]
[InterceptsLocation(@"TestMapActions.cs", 42, 5)]
internal static RouteHandlerBuilder MapGet16(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1757,7 +1757,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 42, 5)]
[InterceptsLocation(@"TestMapActions.cs", 43, 5)]
internal static RouteHandlerBuilder MapGet17(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1852,7 +1852,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 43, 5)]
[InterceptsLocation(@"TestMapActions.cs", 44, 5)]
internal static RouteHandlerBuilder MapGet18(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down Expand Up @@ -1960,7 +1960,7 @@ namespace Microsoft.AspNetCore.Http.Generated
createRequestDelegate);
}

[InterceptsLocation(@"TestMapActions.cs", 44, 5)]
[InterceptsLocation(@"TestMapActions.cs", 45, 5)]
internal static RouteHandlerBuilder MapGet19(
this IEndpointRouteBuilder endpoints,
[StringSyntax("Route")] string pattern,
Expand Down
Loading