From 554456c3804478e42e676e2347a02e9450c2d0ba Mon Sep 17 00:00:00 2001 From: Ryan Nowak Date: Wed, 3 Jul 2019 12:18:43 -0700 Subject: [PATCH 1/2] Fix nullable detection Fixes: #11828 and #11813 --- .../src/DataAnnotationsMetadataProvider.cs | 114 +++++++++++++++--- .../DataAnnotationsMetadataProviderTest.cs | 20 +-- .../NonNullableReferenceTypesTest.cs | 2 +- .../NullableReferenceTypeIntegrationTest.cs | 6 +- 4 files changed, 112 insertions(+), 30 deletions(-) diff --git a/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs b/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs index 97058965b454..46020580dbf2 100644 --- a/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs +++ b/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs @@ -2,11 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.ComponentModel.DataAnnotations; +using System.Diagnostics.Contracts; using System.Linq; using System.Reflection; +using System.Runtime.InteropServices.ComTypes; using Microsoft.AspNetCore.Mvc.ModelBinding; using Microsoft.AspNetCore.Mvc.ModelBinding.Metadata; using Microsoft.Extensions.Localization; @@ -27,6 +30,9 @@ internal class DataAnnotationsMetadataProvider : private const string NullableAttributeFullTypeName = "System.Runtime.CompilerServices.NullableAttribute"; private const string NullableFlagsFieldName = "NullableFlags"; + private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute"; + private const string NullableContextFlagsFieldName = "Flag"; + private readonly IStringLocalizerFactory _stringLocalizerFactory; private readonly MvcOptions _options; private readonly MvcDataAnnotationsLocalizationOptions _localizationOptions; @@ -350,20 +356,44 @@ public void CreateValidationMetadata(ValidationMetadataProviderContext context) if (!_options.SuppressImplicitRequiredAttributeForNonNullableReferenceTypes && requiredAttribute == null && !context.Key.ModelType.IsValueType && - - // Look specifically at attributes on the property/parameter. [Nullable] on - // the type has a different meaning. - IsNonNullable(context.ParameterAttributes ?? context.PropertyAttributes ?? Array.Empty())) + context.Key.MetadataKind != ModelMetadataKind.Type) { - // Since this behavior specifically relates to non-null-ness, we will use the non-default - // option to tolerate empty/whitespace strings. empty/whitespace INPUT will still result in - // a validation error by default because we convert empty/whitespace strings to null - // unless you say otherwise. - requiredAttribute = new RequiredAttribute() + var addInferredRequiredAttribute = false; + if (context.Key.MetadataKind == ModelMetadataKind.Type) + { + // Do nothing. + } + else if (context.Key.MetadataKind == ModelMetadataKind.Property) + { + addInferredRequiredAttribute = IsNullableReferenceType( + context.Key.ContainerType, + member: null, + context.PropertyAttributes); + } + else if (context.Key.MetadataKind == ModelMetadataKind.Parameter) + { + addInferredRequiredAttribute = IsNullableReferenceType( + context.Key.ParameterInfo?.Member.ReflectedType, + context.Key.ParameterInfo.Member, + context.ParameterAttributes); + } + else { - AllowEmptyStrings = true, - }; - attributes.Add(requiredAttribute); + throw new InvalidOperationException("Unsupported ModelMetadataKind: " + context.Key.MetadataKind); + } + + if (addInferredRequiredAttribute) + { + // Since this behavior specifically relates to non-null-ness, we will use the non-default + // option to tolerate empty/whitespace strings. empty/whitespace INPUT will still result in + // a validation error by default because we convert empty/whitespace strings to null + // unless you say otherwise. + requiredAttribute = new RequiredAttribute() + { + AllowEmptyStrings = true, + }; + attributes.Add(requiredAttribute); + } } if (requiredAttribute != null) @@ -419,8 +449,18 @@ private static string GetDisplayGroup(FieldInfo field) return string.Empty; } + internal static bool IsNullableReferenceType(Type containingType, MemberInfo member, IEnumerable attributes) + { + if (HasNullableAttribute(attributes, out var result)) + { + return result; + } + + return IsNullableBasedOnContext(containingType, member); + } + // Internal for testing - internal static bool IsNonNullable(IEnumerable attributes) + internal static bool HasNullableAttribute(IEnumerable attributes, out bool isNullable) { // [Nullable] is compiler synthesized, comparing by name. var nullableAttribute = attributes @@ -428,7 +468,8 @@ internal static bool IsNonNullable(IEnumerable attributes) .FirstOrDefault(); if (nullableAttribute == null) { - return false; + isNullable = false; + return false; // [Nullable] not found } // We don't handle cases where generics and NNRT are used. This runs into a @@ -443,10 +484,51 @@ internal static bool IsNonNullable(IEnumerable attributes) flags.Length >= 0 && flags[0] == 1) // First element is the property/parameter type. { - return true; + isNullable = true; + return true; // [Nullable] found and type is an NNRT } - return false; + isNullable = false; + return true; // [Nullable] found but type is not an NNRT + } + + internal static bool IsNullableBasedOnContext(Type containingType, MemberInfo member) + { + var attributes = member?.GetCustomAttributes(inherit: true) ?? Array.Empty(); + var isNullable = AttributesHasNullableContext(attributes); + if (isNullable != null) + { + return isNullable.Value; + } + + attributes = containingType.GetCustomAttributes(inherit: false); + isNullable = AttributesHasNullableContext(attributes); + if (isNullable != null) + { + return isNullable.Value; + } + + // If we don't find the attribute on the declaring type then repeat at the module level + attributes = containingType.Module.GetCustomAttributes(inherit: false); + isNullable = AttributesHasNullableContext(attributes); + return isNullable ?? false; + + bool? AttributesHasNullableContext(object[] attributes) + { + var nullableContextAttribute = attributes + .Where(a => string.Equals(a.GetType().FullName, NullableContextAttributeFullName, StringComparison.Ordinal)) + .FirstOrDefault(); + if (nullableContextAttribute != null) + { + if (nullableContextAttribute.GetType().GetField(NullableContextFlagsFieldName) is FieldInfo field && + field.GetValue(nullableContextAttribute) is byte @byte) + { + return @byte == 1; // [NullableContext] found + } + } + + return null; + } } } } diff --git a/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs b/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs index 2d260405ef3c..677c452d2d2e 100644 --- a/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs +++ b/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs @@ -1142,7 +1142,7 @@ public void CreateValidationMetadata_NoRequiredAttribute_IsRequiredLeftAlone(boo Assert.Equal(initialValue, context.ValidationMetadata.IsRequired); } - [Fact(Skip = "https://github.com/aspnet/AspNetCore/issues/11828")] + [Fact] public void CreateValidationMetadata_InfersRequiredAttribute_NoNonNullableProperty() { // Arrange @@ -1152,7 +1152,7 @@ public void CreateValidationMetadata_InfersRequiredAttribute_NoNonNullableProper typeof(NullableReferenceTypes), typeof(NullableReferenceTypes).GetProperty(nameof(NullableReferenceTypes.NonNullableReferenceType))); var key = ModelMetadataIdentity.ForProperty( - typeof(NullableReferenceTypes), + typeof(NullableReferenceTypes), nameof(NullableReferenceTypes.NonNullableReferenceType), typeof(string)); var context = new ValidationMetadataProviderContext(key, attributes); @@ -1325,7 +1325,7 @@ public void CreateValidationDetails_ValidatableObject_AlreadyInContext_Ignores() Assert.Same(attribute, validatorMetadata); } - [Fact(Skip = "https://github.com/aspnet/AspNetCore/issues/11828")] + [Fact] public void IsNonNullable_FindsNonNullableProperty() { // Arrange @@ -1333,7 +1333,7 @@ public void IsNonNullable_FindsNonNullableProperty() var property = type.GetProperty(nameof(NullableReferenceTypes.NonNullableReferenceType)); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(property.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, member: null, property.GetCustomAttributes(inherit: true)); // Assert Assert.True(result); @@ -1347,13 +1347,13 @@ public void IsNonNullable_FindsNullableProperty() var property = type.GetProperty(nameof(NullableReferenceTypes.NullableReferenceType)); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(property.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, member: null, property.GetCustomAttributes(inherit: true)); // Assert Assert.False(result); } - [Fact(Skip = "https://github.com/aspnet/AspNetCore/issues/11828")] + [Fact] public void IsNonNullable_FindsNonNullableParameter() { // Arrange @@ -1362,7 +1362,7 @@ public void IsNonNullable_FindsNonNullableParameter() var parameter = method.GetParameters().Where(p => p.Name == "nonNullableParameter").Single(); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(parameter.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, method, parameter.GetCustomAttributes(inherit: true)); // Assert Assert.True(result); @@ -1377,7 +1377,7 @@ public void IsNonNullable_FindsNullableParameter() var parameter = method.GetParameters().Where(p => p.Name == "nullableParameter").Single(); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(parameter.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, method, parameter.GetCustomAttributes(inherit: true)); // Assert Assert.False(result); @@ -1429,12 +1429,12 @@ private DataAnnotationsMetadataProvider CreateIStringLocalizerProvider(bool useS return CreateProvider(options: null, localizationOptions, useStringLocalizer ? stringLocalizerFactory.Object : null); } - private ModelAttributes GetModelAttributes(IEnumerable typeAttributes) + private ModelAttributes GetModelAttributes(IEnumerable typeAttributes) => new ModelAttributes(typeAttributes, Array.Empty(), Array.Empty()); private ModelAttributes GetModelAttributes( IEnumerable typeAttributes, - IEnumerable propertyAttributes) + IEnumerable propertyAttributes) => new ModelAttributes(typeAttributes, propertyAttributes, Array.Empty()); private class KVPEnumGroupAndNameComparer : IEqualityComparer> diff --git a/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs b/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs index 052b180d2d5f..e22462ce342f 100644 --- a/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs +++ b/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs @@ -26,7 +26,7 @@ public NonNullableReferenceTypesTest(MvcTestFixture Date: Tue, 16 Jul 2019 13:09:13 -0700 Subject: [PATCH 2/2] PR feedback --- .../src/DataAnnotationsMetadataProvider.cs | 30 ++++++++++++------- .../Controllers/NonNullableController.cs | 1 - 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs b/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs index 46020580dbf2..e63b02edce4f 100644 --- a/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs +++ b/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs @@ -9,6 +9,7 @@ using System.Diagnostics.Contracts; using System.Linq; using System.Reflection; +using System.Runtime.InteropServices; using System.Runtime.InteropServices.ComTypes; using Microsoft.AspNetCore.Mvc.ModelBinding; using Microsoft.AspNetCore.Mvc.ModelBinding.Metadata; @@ -464,8 +465,7 @@ internal static bool HasNullableAttribute(IEnumerable attributes, out bo { // [Nullable] is compiler synthesized, comparing by name. var nullableAttribute = attributes - .Where(a => string.Equals(a.GetType().FullName, NullableAttributeFullTypeName, StringComparison.Ordinal)) - .FirstOrDefault(); + .FirstOrDefault(a => string.Equals(a.GetType().FullName, NullableAttributeFullTypeName, StringComparison.Ordinal)); if (nullableAttribute == null) { isNullable = false; @@ -494,19 +494,30 @@ internal static bool HasNullableAttribute(IEnumerable attributes, out bo internal static bool IsNullableBasedOnContext(Type containingType, MemberInfo member) { - var attributes = member?.GetCustomAttributes(inherit: true) ?? Array.Empty(); + // The [Nullable] and [NullableContext] attributes are not inherited. + // + // The [NullableContext] attribute can appear on a method or on the module. + var attributes = member?.GetCustomAttributes(inherit: false) ?? Array.Empty(); var isNullable = AttributesHasNullableContext(attributes); if (isNullable != null) { return isNullable.Value; } - attributes = containingType.GetCustomAttributes(inherit: false); - isNullable = AttributesHasNullableContext(attributes); - if (isNullable != null) + // Check on the containing type + var type = containingType; + do { - return isNullable.Value; - } + attributes = type.GetCustomAttributes(inherit: false); + isNullable = AttributesHasNullableContext(attributes); + if (isNullable != null) + { + return isNullable.Value; + } + + type = type.DeclaringType; + } + while (type != null); // If we don't find the attribute on the declaring type then repeat at the module level attributes = containingType.Module.GetCustomAttributes(inherit: false); @@ -516,8 +527,7 @@ internal static bool IsNullableBasedOnContext(Type containingType, MemberInfo me bool? AttributesHasNullableContext(object[] attributes) { var nullableContextAttribute = attributes - .Where(a => string.Equals(a.GetType().FullName, NullableContextAttributeFullName, StringComparison.Ordinal)) - .FirstOrDefault(); + .FirstOrDefault(a => string.Equals(a.GetType().FullName, NullableContextAttributeFullName, StringComparison.Ordinal)); if (nullableContextAttribute != null) { if (nullableContextAttribute.GetType().GetField(NullableContextFlagsFieldName) is FieldInfo field && diff --git a/src/Mvc/test/WebSites/BasicWebSite/Controllers/NonNullableController.cs b/src/Mvc/test/WebSites/BasicWebSite/Controllers/NonNullableController.cs index b5f3d6c3ac0e..881706e9eb92 100644 --- a/src/Mvc/test/WebSites/BasicWebSite/Controllers/NonNullableController.cs +++ b/src/Mvc/test/WebSites/BasicWebSite/Controllers/NonNullableController.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. #nullable enable -using BasicWebSite.Models; using Microsoft.AspNetCore.Mvc; namespace BasicWebSite.Controllers