diff --git a/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs b/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs index 97058965b454..e63b02edce4f 100644 --- a/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs +++ b/src/Mvc/Mvc.DataAnnotations/src/DataAnnotationsMetadataProvider.cs @@ -2,11 +2,15 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.ComponentModel.DataAnnotations; +using System.Diagnostics.Contracts; using System.Linq; using System.Reflection; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.ComTypes; using Microsoft.AspNetCore.Mvc.ModelBinding; using Microsoft.AspNetCore.Mvc.ModelBinding.Metadata; using Microsoft.Extensions.Localization; @@ -27,6 +31,9 @@ internal class DataAnnotationsMetadataProvider : private const string NullableAttributeFullTypeName = "System.Runtime.CompilerServices.NullableAttribute"; private const string NullableFlagsFieldName = "NullableFlags"; + private const string NullableContextAttributeFullName = "System.Runtime.CompilerServices.NullableContextAttribute"; + private const string NullableContextFlagsFieldName = "Flag"; + private readonly IStringLocalizerFactory _stringLocalizerFactory; private readonly MvcOptions _options; private readonly MvcDataAnnotationsLocalizationOptions _localizationOptions; @@ -350,20 +357,44 @@ public void CreateValidationMetadata(ValidationMetadataProviderContext context) if (!_options.SuppressImplicitRequiredAttributeForNonNullableReferenceTypes && requiredAttribute == null && !context.Key.ModelType.IsValueType && - - // Look specifically at attributes on the property/parameter. [Nullable] on - // the type has a different meaning. - IsNonNullable(context.ParameterAttributes ?? context.PropertyAttributes ?? Array.Empty())) + context.Key.MetadataKind != ModelMetadataKind.Type) { - // Since this behavior specifically relates to non-null-ness, we will use the non-default - // option to tolerate empty/whitespace strings. empty/whitespace INPUT will still result in - // a validation error by default because we convert empty/whitespace strings to null - // unless you say otherwise. - requiredAttribute = new RequiredAttribute() + var addInferredRequiredAttribute = false; + if (context.Key.MetadataKind == ModelMetadataKind.Type) + { + // Do nothing. + } + else if (context.Key.MetadataKind == ModelMetadataKind.Property) + { + addInferredRequiredAttribute = IsNullableReferenceType( + context.Key.ContainerType, + member: null, + context.PropertyAttributes); + } + else if (context.Key.MetadataKind == ModelMetadataKind.Parameter) { - AllowEmptyStrings = true, - }; - attributes.Add(requiredAttribute); + addInferredRequiredAttribute = IsNullableReferenceType( + context.Key.ParameterInfo?.Member.ReflectedType, + context.Key.ParameterInfo.Member, + context.ParameterAttributes); + } + else + { + throw new InvalidOperationException("Unsupported ModelMetadataKind: " + context.Key.MetadataKind); + } + + if (addInferredRequiredAttribute) + { + // Since this behavior specifically relates to non-null-ness, we will use the non-default + // option to tolerate empty/whitespace strings. empty/whitespace INPUT will still result in + // a validation error by default because we convert empty/whitespace strings to null + // unless you say otherwise. + requiredAttribute = new RequiredAttribute() + { + AllowEmptyStrings = true, + }; + attributes.Add(requiredAttribute); + } } if (requiredAttribute != null) @@ -419,16 +450,26 @@ private static string GetDisplayGroup(FieldInfo field) return string.Empty; } + internal static bool IsNullableReferenceType(Type containingType, MemberInfo member, IEnumerable attributes) + { + if (HasNullableAttribute(attributes, out var result)) + { + return result; + } + + return IsNullableBasedOnContext(containingType, member); + } + // Internal for testing - internal static bool IsNonNullable(IEnumerable attributes) + internal static bool HasNullableAttribute(IEnumerable attributes, out bool isNullable) { // [Nullable] is compiler synthesized, comparing by name. var nullableAttribute = attributes - .Where(a => string.Equals(a.GetType().FullName, NullableAttributeFullTypeName, StringComparison.Ordinal)) - .FirstOrDefault(); + .FirstOrDefault(a => string.Equals(a.GetType().FullName, NullableAttributeFullTypeName, StringComparison.Ordinal)); if (nullableAttribute == null) { - return false; + isNullable = false; + return false; // [Nullable] not found } // We don't handle cases where generics and NNRT are used. This runs into a @@ -443,10 +484,61 @@ internal static bool IsNonNullable(IEnumerable attributes) flags.Length >= 0 && flags[0] == 1) // First element is the property/parameter type. { - return true; + isNullable = true; + return true; // [Nullable] found and type is an NNRT } - return false; + isNullable = false; + return true; // [Nullable] found but type is not an NNRT + } + + internal static bool IsNullableBasedOnContext(Type containingType, MemberInfo member) + { + // The [Nullable] and [NullableContext] attributes are not inherited. + // + // The [NullableContext] attribute can appear on a method or on the module. + var attributes = member?.GetCustomAttributes(inherit: false) ?? Array.Empty(); + var isNullable = AttributesHasNullableContext(attributes); + if (isNullable != null) + { + return isNullable.Value; + } + + // Check on the containing type + var type = containingType; + do + { + attributes = type.GetCustomAttributes(inherit: false); + isNullable = AttributesHasNullableContext(attributes); + if (isNullable != null) + { + return isNullable.Value; + } + + type = type.DeclaringType; + } + while (type != null); + + // If we don't find the attribute on the declaring type then repeat at the module level + attributes = containingType.Module.GetCustomAttributes(inherit: false); + isNullable = AttributesHasNullableContext(attributes); + return isNullable ?? false; + + bool? AttributesHasNullableContext(object[] attributes) + { + var nullableContextAttribute = attributes + .FirstOrDefault(a => string.Equals(a.GetType().FullName, NullableContextAttributeFullName, StringComparison.Ordinal)); + if (nullableContextAttribute != null) + { + if (nullableContextAttribute.GetType().GetField(NullableContextFlagsFieldName) is FieldInfo field && + field.GetValue(nullableContextAttribute) is byte @byte) + { + return @byte == 1; // [NullableContext] found + } + } + + return null; + } } } } diff --git a/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs b/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs index 2d260405ef3c..677c452d2d2e 100644 --- a/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs +++ b/src/Mvc/Mvc.DataAnnotations/test/DataAnnotationsMetadataProviderTest.cs @@ -1142,7 +1142,7 @@ public void CreateValidationMetadata_NoRequiredAttribute_IsRequiredLeftAlone(boo Assert.Equal(initialValue, context.ValidationMetadata.IsRequired); } - [Fact(Skip = "https://github.com/aspnet/AspNetCore/issues/11828")] + [Fact] public void CreateValidationMetadata_InfersRequiredAttribute_NoNonNullableProperty() { // Arrange @@ -1152,7 +1152,7 @@ public void CreateValidationMetadata_InfersRequiredAttribute_NoNonNullableProper typeof(NullableReferenceTypes), typeof(NullableReferenceTypes).GetProperty(nameof(NullableReferenceTypes.NonNullableReferenceType))); var key = ModelMetadataIdentity.ForProperty( - typeof(NullableReferenceTypes), + typeof(NullableReferenceTypes), nameof(NullableReferenceTypes.NonNullableReferenceType), typeof(string)); var context = new ValidationMetadataProviderContext(key, attributes); @@ -1325,7 +1325,7 @@ public void CreateValidationDetails_ValidatableObject_AlreadyInContext_Ignores() Assert.Same(attribute, validatorMetadata); } - [Fact(Skip = "https://github.com/aspnet/AspNetCore/issues/11828")] + [Fact] public void IsNonNullable_FindsNonNullableProperty() { // Arrange @@ -1333,7 +1333,7 @@ public void IsNonNullable_FindsNonNullableProperty() var property = type.GetProperty(nameof(NullableReferenceTypes.NonNullableReferenceType)); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(property.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, member: null, property.GetCustomAttributes(inherit: true)); // Assert Assert.True(result); @@ -1347,13 +1347,13 @@ public void IsNonNullable_FindsNullableProperty() var property = type.GetProperty(nameof(NullableReferenceTypes.NullableReferenceType)); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(property.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, member: null, property.GetCustomAttributes(inherit: true)); // Assert Assert.False(result); } - [Fact(Skip = "https://github.com/aspnet/AspNetCore/issues/11828")] + [Fact] public void IsNonNullable_FindsNonNullableParameter() { // Arrange @@ -1362,7 +1362,7 @@ public void IsNonNullable_FindsNonNullableParameter() var parameter = method.GetParameters().Where(p => p.Name == "nonNullableParameter").Single(); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(parameter.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, method, parameter.GetCustomAttributes(inherit: true)); // Assert Assert.True(result); @@ -1377,7 +1377,7 @@ public void IsNonNullable_FindsNullableParameter() var parameter = method.GetParameters().Where(p => p.Name == "nullableParameter").Single(); // Act - var result = DataAnnotationsMetadataProvider.IsNonNullable(parameter.GetCustomAttributes(inherit: true)); + var result = DataAnnotationsMetadataProvider.IsNullableReferenceType(type, method, parameter.GetCustomAttributes(inherit: true)); // Assert Assert.False(result); @@ -1429,12 +1429,12 @@ private DataAnnotationsMetadataProvider CreateIStringLocalizerProvider(bool useS return CreateProvider(options: null, localizationOptions, useStringLocalizer ? stringLocalizerFactory.Object : null); } - private ModelAttributes GetModelAttributes(IEnumerable typeAttributes) + private ModelAttributes GetModelAttributes(IEnumerable typeAttributes) => new ModelAttributes(typeAttributes, Array.Empty(), Array.Empty()); private ModelAttributes GetModelAttributes( IEnumerable typeAttributes, - IEnumerable propertyAttributes) + IEnumerable propertyAttributes) => new ModelAttributes(typeAttributes, propertyAttributes, Array.Empty()); private class KVPEnumGroupAndNameComparer : IEqualityComparer> diff --git a/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs b/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs index 052b180d2d5f..e22462ce342f 100644 --- a/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs +++ b/src/Mvc/test/Mvc.FunctionalTests/NonNullableReferenceTypesTest.cs @@ -26,7 +26,7 @@ public NonNullableReferenceTypesTest(MvcTestFixture