From f477237a54d44348ca0dbe306c2e98db200f469f Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 27 Jan 2019 15:33:32 +0100 Subject: [PATCH 01/11] Enhance nullability check for "==" and "!=" operators for LINQ provider --- .../Entities/AnotherEntityRequired.cs | 33 ++ .../Mappings/AnotherEntityRequired.hbm.xml | 23 ++ .../Async/Linq/NullComparisonTests.cs | 307 ++++++++++++++++++ src/NHibernate.Test/Linq/LinqTestCase.cs | 3 +- .../Linq/NullComparisonTests.cs | 306 +++++++++++++++++ .../Expressions/NhAggregatedExpression.cs | 2 + .../Linq/Expressions/NhCountExpression.cs | 2 + .../Functions/BaseHqlGeneratorForMethod.cs | 19 +- .../Linq/Functions/CompareGenerator.cs | 1 + .../Linq/Functions/DictionaryGenerator.cs | 4 +- .../Linq/Functions/EqualsGenerator.cs | 6 +- .../Functions/GetValueOrDefaultGenerator.cs | 4 +- .../Linq/Functions/IHqlGeneratorForMethod.cs | 22 +- .../Linq/Functions/QueryableGenerator.cs | 8 +- .../Linq/Functions/StringGenerator.cs | 10 +- .../Visitors/HqlGeneratorExpressionVisitor.cs | 300 ++++++++++++++++- 16 files changed, 1027 insertions(+), 23 deletions(-) create mode 100644 src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs create mode 100644 src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml diff --git a/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs b/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs new file mode 100644 index 00000000000..9b6492d4cb0 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/AnotherEntityRequired.cs @@ -0,0 +1,33 @@ +using System.Collections.Generic; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class AnotherEntityRequired + { + public virtual int Id { get; set; } + + public virtual string Output { get; set; } + + public virtual string Input { get; set; } + + public virtual Address Address { get; set; } + + public virtual AnotherEntityNullability InputNullability { get; set; } + + public virtual string NullableOutput { get; set; } + + public virtual AnotherEntityRequired NullableAnotherEntityRequired { get; set; } + + public virtual int? NullableAnotherEntityRequiredId { get; set; } + + public virtual ISet RelatedItems { get; set; } = new HashSet(); + + public virtual bool? NullableBool { get; set; } + } + + public enum AnotherEntityNullability + { + False = 0, + True = 1 + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml new file mode 100644 index 00000000000..0d9efe4136f --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/AnotherEntityRequired.hbm.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index cecc6778c74..af1510ad68d 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -15,10 +15,12 @@ using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { using System.Threading.Tasks; + using System.Threading; [TestFixture] public class NullComparisonTestsAsync : LinqTestCase { @@ -28,6 +30,271 @@ public class NullComparisonTestsAsync : LinqTestCase private static readonly AnotherEntity BothNull = new AnotherEntity(); private static readonly AnotherEntity BothDifferent = new AnotherEntity {Input = "input", Output = "output"}; + [Test] + public async Task NullInequalityWithNotNullAsync() + { + IQueryable q; + + q = session.Query().Where(o => o.Input != null); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => null != o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => o.InputNullability != AnotherEntityNullability.True); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => AnotherEntityNullability.True != o.InputNullability); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent)); + + q = session.Query().Where(o => "input" != o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input != "input"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input != o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => o.Output != o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => o.Input != o.NullableOutput); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.NullableOutput != o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output != o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.Input != o.NullableAnotherEntityRequired.Output); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input != o.Output); + await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); + + q = session.Query().Where(o => o.Output != o.NullableAnotherEntityRequired.Input); + await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); + + q = session.Query().Where(o => 3 != o.NullableOutput.Length); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput.Length != 3); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet)); + + q = session.Query().Where(o => 3 != o.Input.Length); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.Input.Length != 3); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != 0); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value != 0); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != null && o.NullableOutput != "test"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput != "test"); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequiredId.Value); + await (ExpectAsync(q, Does.Contain("or case").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output != o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, BothDifferent, InputSet, BothNull)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, InputSet, OutputSet, BothDifferent, BothNull)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && r.Output != o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothDifferent, OutputSet)); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) != o.Output); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.Input + o.Output) != o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame, BothDifferent)); + + q = session.Query().Where(o => o.RelatedItems.Count != 1); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Take(10).All(r => r.Output != null) != o.NullableBool); + await (ExpectAllAsync(q, Does.Not.Contain("or case").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); + await (ExpectAllAsync(q, Does.Contain("or (").IgnoreCase)); + + q = session.Query().Where(o => o.Address.Street != o.Output); + await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); + + q = session.Query().Where(o => o.Address.City != o.Output); + await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothNull)); + + q = session.Query().Where(o => o.Address.City != null && o.Address.City != o.Output); + await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase)); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); + await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != (o.NullableOutput.Length > 0)); + await (ExpectAsync(q, Does.Not.Contain("or case").IgnoreCase)); + } + + [Test] + public async Task NullEqualityWithNotNullAsync() + { + IQueryable q; + + q = session.Query().Where(o => o.Input == null); + await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull)); + + q = session.Query().Where(o => null == o.Input); + await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull)); + + q = session.Query().Where(o => o.InputNullability == AnotherEntityNullability.True); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet)); + + q = session.Query().Where(o => AnotherEntityNullability.True == o.InputNullability); + await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet)); + + q = session.Query().Where(o => "input" == o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.Input == "input"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent)); + + q = session.Query().Where(o => o.Input == o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Output == o.Input); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input == o.NullableOutput); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.NullableOutput == o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output == o.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input == o.NullableAnotherEntityRequired.Output); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input == o.Output); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Output == o.NullableAnotherEntityRequired.Input); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => 3 == o.Input.Length); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.Input.Length == 3); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequiredId ?? 0)); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == 0); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value == 0); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput == "test"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput == "test"); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, OutputSet, BothDifferent, BothSame)); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequiredId.Value); + await (ExpectAllAsync(q, Does.Contain("Id is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output == o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.Input)); + await (ExpectAsync(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame, BothNull, InputSet, OutputSet)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.NullableOutput)); + await (ExpectAllAsync(q, Does.Contain("Output is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && o.NullableOutput != null && r.Output == o.NullableOutput)); + await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame, BothDifferent, OutputSet)); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) == o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => (o.Input + o.Output) == o.Output); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Count == 1); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => !o.Input.Equals(o.Output)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.Output.Equals(o.Input)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.Input.Equals(o.NullableOutput)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.Input)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent)); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.Address.City == o.NullableOutput); + await (ExpectAllAsync(q, Does.Contain("Output is null").IgnoreCase)); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput); + await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame)); + } + + [Test] public async Task NullEqualityAsync() { @@ -316,5 +583,45 @@ private string Key(AnotherEntity e) { return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); } + + private Task ExpectAllAsync(IQueryable q, IResolveConstraint sqlConstraint) + { + return ExpectAsync(q, sqlConstraint, BothNull, BothSame, BothDifferent, InputSet, OutputSet); + } + + private async Task ExpectAsync(IQueryable q, IResolveConstraint sqlConstraint, params AnotherEntity[] entities) + { + IList results; + if (sqlConstraint == null) + { + results = await (GetResultsAsync(q)); + } + else + { + using (var sqlLog = new SqlLogSpy()) + { + results = await (GetResultsAsync(q)); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + + IList check = entities.OrderBy(Key).ToList(); + + Assert.That(results.Count, Is.EqualTo(check.Count)); + for (var i = 0; i < check.Count; i++) + { + Assert.That(Key(results[i]), Is.EqualTo(Key(check[i]))); + } + } + + private async Task> GetResultsAsync(IQueryable q, CancellationToken cancellationToken = default(CancellationToken)) + { + return (await (q.ToListAsync(cancellationToken))).OrderBy(Key).ToList(); + } + + private static string Key(AnotherEntityRequired e) + { + return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); + } } } diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index 26503786bb0..e047732d7ad 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -29,6 +29,7 @@ protected override string[] Mappings "Northwind.Mappings.Supplier.hbm.xml", "Northwind.Mappings.Territory.hbm.xml", "Northwind.Mappings.AnotherEntity.hbm.xml", + "Northwind.Mappings.AnotherEntityRequired.hbm.xml", "Northwind.Mappings.Role.hbm.xml", "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", @@ -69,4 +70,4 @@ public static void AssertByIds(IEnumerable entities, TId[ Assert.That(entities.Select(x => entityIdGetter(x)), Is.EquivalentTo(expectedIds)); } } -} \ No newline at end of file +} diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index e8a4c6eec7f..0d9363bb5ed 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -5,6 +5,7 @@ using NHibernate.Linq; using NHibernate.DomainModel.Northwind.Entities; using NUnit.Framework; +using NUnit.Framework.Constraints; namespace NHibernate.Test.Linq { @@ -17,6 +18,271 @@ public class NullComparisonTests : LinqTestCase private static readonly AnotherEntity BothNull = new AnotherEntity(); private static readonly AnotherEntity BothDifferent = new AnotherEntity {Input = "input", Output = "output"}; + [Test] + public void NullInequalityWithNotNull() + { + IQueryable q; + + q = session.Query().Where(o => o.Input != null); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => null != o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => o.InputNullability != AnotherEntityNullability.True); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => AnotherEntityNullability.True != o.InputNullability); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, InputSet, BothSame, BothDifferent); + + q = session.Query().Where(o => "input" != o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input != "input"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input != o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => o.Output != o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => o.Input != o.NullableOutput); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.NullableOutput != o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output != o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.Input != o.NullableAnotherEntityRequired.Output); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input != o.Output); + Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); + + q = session.Query().Where(o => o.Output != o.NullableAnotherEntityRequired.Input); + Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); + + q = session.Query().Where(o => 3 != o.NullableOutput.Length); + Expect(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet); + + q = session.Query().Where(o => o.NullableOutput.Length != 3); + Expect(q, Does.Contain("is null").IgnoreCase, InputSet, BothDifferent, BothNull, OutputSet); + + q = session.Query().Where(o => 3 != o.Input.Length); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.Input.Length != 3); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) != (o.NullableAnotherEntityRequiredId ?? 0)); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() != o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != 0); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value != 0); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != null && o.NullableOutput != "test"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput != "test"); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequiredId.Value); + Expect(q, Does.Contain("or case").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output != o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, BothDifferent, InputSet, BothNull); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Contain("Output is null").IgnoreCase, InputSet, OutputSet, BothDifferent, BothNull); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && r.Output != o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothDifferent, OutputSet); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) != o.Output); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.Input + o.Output) != o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame, BothDifferent); + + q = session.Query().Where(o => o.RelatedItems.Count != 1); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Take(10).All(r => r.Output != null) != o.NullableBool); + ExpectAll(q, Does.Not.Contain("or case").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); + ExpectAll(q, Does.Contain("or (").IgnoreCase); + + q = session.Query().Where(o => o.Address.Street != o.Output); + Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); + + q = session.Query().Where(o => o.Address.City != o.Output); + Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothNull); + + q = session.Query().Where(o => o.Address.City != null && o.Address.City != o.Output); + Expect(q, Does.Not.Contain("Output is null").IgnoreCase); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); + Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != (o.NullableOutput.Length > 0)); + Expect(q, Does.Not.Contain("or case").IgnoreCase); + } + + [Test] + public void NullEqualityWithNotNull() + { + IQueryable q; + + q = session.Query().Where(o => o.Input == null); + Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull); + + q = session.Query().Where(o => null == o.Input); + Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull); + + q = session.Query().Where(o => o.InputNullability == AnotherEntityNullability.True); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet); + + q = session.Query().Where(o => AnotherEntityNullability.True == o.InputNullability); + Expect(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet); + + q = session.Query().Where(o => "input" == o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.Input == "input"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothDifferent); + + q = session.Query().Where(o => o.Input == o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Output == o.Input); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input == o.NullableOutput); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.NullableOutput == o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Output == o.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input == o.NullableAnotherEntityRequired.Output); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.Input == o.Output); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Output == o.NullableAnotherEntityRequired.Input); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => 3 == o.Input.Length); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.Input.Length == 3); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0)); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId ?? 0) == (o.NullableAnotherEntityRequiredId ?? 0)); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault()); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.GetValueOrDefault() == o.NullableAnotherEntityRequiredId.GetValueOrDefault()); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == 0); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value == 0); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput == "test"); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput == "test"); + Expect(q, Does.Not.Contain("is null").IgnoreCase, OutputSet, BothDifferent, BothSame); + + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequiredId.Value); + ExpectAll(q, Does.Contain("Id is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Any(r => r.Output == o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.Input)); + Expect(q, Does.Not.Contain("Input is null").IgnoreCase.And.Not.Contain("Output is null").IgnoreCase, BothSame, BothNull, InputSet, OutputSet); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output == o.NullableOutput)); + ExpectAll(q, Does.Contain("Output is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null && o.NullableOutput != null && r.Output == o.NullableOutput)); + Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame, BothDifferent, OutputSet); + + q = session.Query().Where(o => (o.NullableOutput + o.Output) == o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => (o.Input + o.Output) == o.Output); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Count == 1); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => !o.Input.Equals(o.Output)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.Output.Equals(o.Input)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.Input.Equals(o.NullableOutput)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.Input)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + Expect(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => !o.NullableOutput.Equals(o.NullableOutput)); + Expect(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.Address.City == o.NullableOutput); + ExpectAll(q, Does.Contain("Output is null").IgnoreCase); + + q = session.Query().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput); + Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame); + } + + [Test] public void NullEquality() { @@ -305,5 +571,45 @@ private string Key(AnotherEntity e) { return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); } + + private void ExpectAll(IQueryable q, IResolveConstraint sqlConstraint) + { + Expect(q, sqlConstraint, BothNull, BothSame, BothDifferent, InputSet, OutputSet); + } + + private void Expect(IQueryable q, IResolveConstraint sqlConstraint, params AnotherEntity[] entities) + { + IList results; + if (sqlConstraint == null) + { + results = GetResults(q); + } + else + { + using (var sqlLog = new SqlLogSpy()) + { + results = GetResults(q); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + + IList check = entities.OrderBy(Key).ToList(); + + Assert.That(results.Count, Is.EqualTo(check.Count)); + for (var i = 0; i < check.Count; i++) + { + Assert.That(Key(results[i]), Is.EqualTo(Key(check[i]))); + } + } + + private IList GetResults(IQueryable q) + { + return q.ToList().OrderBy(Key).ToList(); + } + + private static string Key(AnotherEntityRequired e) + { + return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); + } } } diff --git a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs index 6b705d3d92e..ab60d1a3d29 100644 --- a/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhAggregatedExpression.cs @@ -16,6 +16,8 @@ protected NhAggregatedExpression(Expression expression, System.Type type) Type = type; } + public virtual bool AllowsNullableReturnType => true; + public sealed override System.Type Type { get; } public Expression Expression { get; } diff --git a/src/NHibernate/Linq/Expressions/NhCountExpression.cs b/src/NHibernate/Linq/Expressions/NhCountExpression.cs index 6dc698add5c..e41ed926410 100644 --- a/src/NHibernate/Linq/Expressions/NhCountExpression.cs +++ b/src/NHibernate/Linq/Expressions/NhCountExpression.cs @@ -10,6 +10,8 @@ protected NhCountExpression(Expression expression, System.Type type) { } + public override bool AllowsNullableReturnType => false; + protected override Expression Accept(NhExpressionVisitor visitor) { return visitor.VisitNhCount(this); diff --git a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs index 3a9462d49ef..d65acba8b87 100644 --- a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs @@ -7,10 +7,17 @@ namespace NHibernate.Linq.Functions { - public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod - { - public IEnumerable SupportedMethods { get; protected set; } + public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGeneratorForMethodExtended + { + public IEnumerable SupportedMethods { get; protected set; } - public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); - } -} \ No newline at end of file + public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); + + public virtual bool AllowsNullableReturnType(MethodInfo method) => true; + + bool IHqlGeneratorForMethodExtended.AllowsNullableReturnType(MethodInfo method) + { + return AllowsNullableReturnType(method); + } + } +} diff --git a/src/NHibernate/Linq/Functions/CompareGenerator.cs b/src/NHibernate/Linq/Functions/CompareGenerator.cs index a819bcd8151..e703e6c62e7 100644 --- a/src/NHibernate/Linq/Functions/CompareGenerator.cs +++ b/src/NHibernate/Linq/Functions/CompareGenerator.cs @@ -51,6 +51,7 @@ internal static bool IsCompareMethod(MethodInfo methodInfo) methodInfo.DeclaringType.FullName == "System.Data.Services.Providers.DataServiceProviderMethods"; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; public CompareGenerator() { SupportedMethods = ActingMethods.ToArray(); diff --git a/src/NHibernate/Linq/Functions/DictionaryGenerator.cs b/src/NHibernate/Linq/Functions/DictionaryGenerator.cs index 6131c4d7ab8..eb583d0cd2f 100644 --- a/src/NHibernate/Linq/Functions/DictionaryGenerator.cs +++ b/src/NHibernate/Linq/Functions/DictionaryGenerator.cs @@ -25,6 +25,8 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, public class DictionaryContainsKeyGenerator : BaseHqlGeneratorForMethod { + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.In(visitor.Visit(arguments[0]).AsExpression(), treeBuilder.Indices(visitor.Visit(targetObject).AsExpression())); @@ -98,4 +100,4 @@ protected override string MethodName get { return "get_Item"; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Functions/EqualsGenerator.cs b/src/NHibernate/Linq/Functions/EqualsGenerator.cs index 27165978b34..55ec39e40ee 100644 --- a/src/NHibernate/Linq/Functions/EqualsGenerator.cs +++ b/src/NHibernate/Linq/Functions/EqualsGenerator.cs @@ -63,14 +63,14 @@ public EqualsGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { Expression lhs = arguments.Count == 1 ? targetObject : arguments[0]; Expression rhs = arguments.Count == 1 ? arguments[0] : arguments[1]; - return treeBuilder.Equality( - visitor.Visit(lhs).ToArithmeticExpression(), - visitor.Visit(rhs).ToArithmeticExpression()); + return visitor.Visit(Expression.Equal(lhs, rhs)); } } } diff --git a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs index 87c8efb01a9..33cb12c2c6c 100644 --- a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs +++ b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs @@ -9,7 +9,7 @@ namespace NHibernate.Linq.Functions { - internal class GetValueOrDefaultGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator + internal class GetValueOrDefaultGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator, IHqlGeneratorForMethodExtended { public bool SupportsMethod(MethodInfo method) { @@ -40,5 +40,7 @@ private static HqlExpression GetRhs(MethodInfo method, ReadOnlyCollection !method.ReturnType.IsValueType; } } diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index 06b1545ef23..fde4ffd45f0 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -12,4 +12,24 @@ public interface IHqlGeneratorForMethod IEnumerable SupportedMethods { get; } HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); } -} \ No newline at end of file + + // 6.0 TODO: Merge into IHqlGeneratorForMethod + internal interface IHqlGeneratorForMethodExtended + { + bool AllowsNullableReturnType(MethodInfo method); + } + + internal static class HqlGeneratorForMethodExtensions + { + // 6.0 TODO: Remove + public static bool AllowsNullableReturnType(this IHqlGeneratorForMethod generator, MethodInfo method) + { + if (generator is IHqlGeneratorForMethodExtended extendedGenerator) + { + return extendedGenerator.AllowsNullableReturnType(method); + } + + return true; + } + } +} diff --git a/src/NHibernate/Linq/Functions/QueryableGenerator.cs b/src/NHibernate/Linq/Functions/QueryableGenerator.cs index 4f3a9568c69..3da75fd0c99 100644 --- a/src/NHibernate/Linq/Functions/QueryableGenerator.cs +++ b/src/NHibernate/Linq/Functions/QueryableGenerator.cs @@ -22,6 +22,8 @@ public AnyHqlGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { HqlAlias alias = null; @@ -59,6 +61,8 @@ public AllHqlGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // All has two arguments. Arg 1 is the source and arg 2 is the predicate @@ -148,6 +152,8 @@ public CollectionContainsGenerator() }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // TODO - alias generator @@ -170,4 +176,4 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, where)); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Functions/StringGenerator.cs b/src/NHibernate/Linq/Functions/StringGenerator.cs index 7ec127b1f17..31edf4a42d1 100644 --- a/src/NHibernate/Linq/Functions/StringGenerator.cs +++ b/src/NHibernate/Linq/Functions/StringGenerator.cs @@ -10,7 +10,7 @@ namespace NHibernate.Linq.Functions { - public class LikeGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator + public class LikeGenerator : IHqlGeneratorForMethod, IRuntimeMethodHqlGenerator, IHqlGeneratorForMethodExtended { public IEnumerable SupportedMethods { @@ -57,6 +57,8 @@ public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) { return this; } + + public bool AllowsNullableReturnType(MethodInfo method) => false; } public class LengthGenerator : BaseHqlGeneratorForProperty @@ -79,6 +81,8 @@ public StartsWithGenerator() SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.StartsWith(null)) }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.Like( @@ -96,6 +100,8 @@ public EndsWithGenerator() SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.EndsWith(null)) }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.Like( @@ -113,6 +119,8 @@ public ContainsGenerator() SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(x => x.Contains(null)) }; } + public override bool AllowsNullableReturnType(MethodInfo method) => false; + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { return treeBuilder.Like( diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 4e6df61afba..8cc5c8089e8 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Data; using System.Dynamic; using System.Linq; @@ -7,12 +8,16 @@ using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Hql.Ast.ANTLR; +using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; +using NHibernate.Mapping.ByCode; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; +using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; namespace NHibernate.Linq.Visitors { @@ -21,10 +26,21 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; + private readonly Dictionary> _equalityNotNullMembers = + new Dictionary>(); + + private static readonly HashSet NotNullOperators = new HashSet() + { + typeof(AllResultOperator), + typeof(AnyResultOperator), + typeof(ContainsResultOperator), + typeof(CountResultOperator), + typeof(LongCountResultOperator) + }; public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { - return new HqlGeneratorExpressionVisitor(parameters).VisitExpression(expression); + return new HqlGeneratorExpressionVisitor(parameters).Visit(expression); } public HqlGeneratorExpressionVisitor(VisitorParameters parameters) @@ -292,6 +308,94 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi return VisitExpression(expression.Comparison); } + private void SearchForNotNullMembersCheck(BinaryExpression expression) + { + // Check for a member not null check that has a not equals expression + // Example: o.Status != null && o.Status != "New" + // Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus) + // Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus) + if (expression.NodeType != ExpressionType.AndAlso || + expression.Right.NodeType != ExpressionType.NotEqual && + expression.Right.NodeType != ExpressionType.Equal || + expression.Left.NodeType != ExpressionType.AndAlso && + expression.Left.NodeType != ExpressionType.NotEqual) + { + return; + } + + // Skip if there are no member access expressions on the right side + var notEqualExpression = (BinaryExpression) expression.Right; + if (!IsMemberAccess(notEqualExpression.Left) && !IsMemberAccess(notEqualExpression.Right)) + { + return; + } + + var notNullMembers = new List(); + // We may have multiple conditions + // Example: o.Status != null && o.OldStatus != null + if (expression.Left.NodeType == ExpressionType.AndAlso) + { + FindAllNotNullMembers((BinaryExpression) expression.Left, notNullMembers); + } + else + { + FindNotNullMember((BinaryExpression) expression.Left, notNullMembers); + } + + if (notNullMembers.Count > 0) + { + _equalityNotNullMembers[notEqualExpression] = notNullMembers; + } + } + + private static bool IsMemberAccess(Expression expression) + { + if (expression.NodeType == ExpressionType.MemberAccess) + { + return true; + } + + // Nullable members can be wrapped in a convert expression + return expression is UnaryExpression unaryExpression && unaryExpression.Operand.NodeType == ExpressionType.MemberAccess; + } + + private static void FindAllNotNullMembers(BinaryExpression andAlsoExpression, List notNullMembers) + { + if (andAlsoExpression.Right.NodeType == ExpressionType.NotEqual) + { + FindNotNullMember((BinaryExpression) andAlsoExpression.Right, notNullMembers); + } + else if (andAlsoExpression.Right.NodeType == ExpressionType.AndAlso) + { + FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Right, notNullMembers); + } + else + { + return; + } + + if (andAlsoExpression.Left.NodeType == ExpressionType.NotEqual) + { + FindNotNullMember((BinaryExpression) andAlsoExpression.Left, notNullMembers); + } + else if (andAlsoExpression.Left.NodeType == ExpressionType.AndAlso) + { + FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Left, notNullMembers); + } + } + + private static void FindNotNullMember(BinaryExpression notEqualExpression, List notNullMembers) + { + if (notEqualExpression.Left.NodeType == ExpressionType.MemberAccess && VisitorUtil.IsNullConstant(notEqualExpression.Right)) + { + notNullMembers.Add((MemberExpression) notEqualExpression.Left); + } + else if (VisitorUtil.IsNullConstant(notEqualExpression.Left) && notEqualExpression.Right.NodeType == ExpressionType.MemberAccess) + { + notNullMembers.Add((MemberExpression) notEqualExpression.Right); + } + } + protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { if (expression.NodeType == ExpressionType.Equal) @@ -303,6 +407,8 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return TranslateInequalityComparison(expression); } + SearchForNotNullMembersCheck(expression); + var lhs = VisitExpression(expression.Left).AsExpression(); var rhs = VisitExpression(expression.Right).AsExpression(); @@ -384,8 +490,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) return _hqlTreeBuilder.IsNotNull(lhs); } - var lhsNullable = IsNullable(lhs); - var rhsNullable = IsNullable(rhs); + var lhsNullable = IsNullable(expression.Left, expression); + var rhsNullable = IsNullable(expression.Right, expression); var inequality = _hqlTreeBuilder.Inequality(lhs, rhs); @@ -447,8 +553,8 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) return _hqlTreeBuilder.IsNull((lhs)); } - var lhsNullable = IsNullable(lhs); - var rhsNullable = IsNullable(rhs); + var lhsNullable = IsNullable(expression.Left, expression); + var rhsNullable = IsNullable(expression.Right, expression); var equality = _hqlTreeBuilder.Equality(lhs, rhs); @@ -467,10 +573,188 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) _hqlTreeBuilder.IsNull(rhs2))); } - static bool IsNullable(HqlExpression original) + private bool IsNullable(Expression expression, BinaryExpression equalityExpression) + { + var currentExpression = expression; + while (true) + { + switch (currentExpression.NodeType) + { + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.TypeAs: + var unaryExpression = (UnaryExpression) currentExpression; + return IsNullable(unaryExpression.Operand, equalityExpression); // a cast will not return null if the operand is not null + case ExpressionType.Not: + case ExpressionType.And: + case ExpressionType.Or: + case ExpressionType.ExclusiveOr: + case ExpressionType.LeftShift: + case ExpressionType.RightShift: + case ExpressionType.AndAlso: + case ExpressionType.OrElse: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + return false; + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Power: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + var binaryExpression = (BinaryExpression) currentExpression; + return IsNullable(binaryExpression.Left, equalityExpression) || IsNullable(binaryExpression.Right, equalityExpression); + case ExpressionType.ArrayIndex: + return true; // for indexed lists we cannot determine whether the item will be null or not + case ExpressionType.Coalesce: + return IsNullable(((BinaryExpression) currentExpression).Right, equalityExpression); + case ExpressionType.Conditional: + var conditionalExpression = (ConditionalExpression) currentExpression; + return IsNullable(conditionalExpression.IfTrue, equalityExpression) || + IsNullable(conditionalExpression.IfFalse, equalityExpression); + case ExpressionType.Call: + var methodInfo = ((MethodCallExpression) currentExpression).Method; + return !_functionRegistry.TryGetGenerator(methodInfo, out var method) || method.AllowsNullableReturnType(methodInfo); + case ExpressionType.MemberAccess: + var memberExpression = (MemberExpression) currentExpression; + + if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _)) + { + // We have to skip the property as it will be converted to a function that can return null + // if the argument is null + currentExpression = memberExpression.Expression; + continue; + } + + var memberType = ReflectHelper.GetPropertyOrFieldType(memberExpression.Member); + if (memberType?.IsValueType == true && !memberType.IsNullable()) + { + currentExpression = memberExpression.Expression; + continue; + } + + // Check if there was a not null check prior the equality expression + if (( + equalityExpression.NodeType == ExpressionType.NotEqual || + equalityExpression.NodeType == ExpressionType.Equal + ) && + _equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) && + notNullMembers.Any(o => AreEqual(o, memberExpression))) + { + return false; + } + + // We have to check the member mapping to determine if is nullable + var entityName = TryGetEntityName(memberExpression); + if (entityName == null) + { + return true; // not mapped + } + + var persister = _parameters.SessionFactory.GetEntityPersister(entityName); + var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name); + if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value]) + { + return true; // not mapped or nullable + } + + currentExpression = memberExpression.Expression; + continue; + case ExpressionType.Extension: + switch (currentExpression) + { + case QuerySourceReferenceExpression querySourceReferenceExpression: + switch (querySourceReferenceExpression.ReferencedQuerySource) + { + case MainFromClause _: + return false; // we reached to the root expression, there were no nullable expressions + case NhJoinClause joinClause: + return IsNullable(joinClause.FromExpression, equalityExpression); + default: + return true; // unknown query source + } + case SubQueryExpression subQuery: + if (subQuery.QueryModel.SelectClause.Selector is NhAggregatedExpression subQueryAggregatedExpression) + { + return subQueryAggregatedExpression.AllowsNullableReturnType; + } + else if (subQuery.QueryModel.ResultOperators.Any(o => NotNullOperators.Contains(o.GetType()))) + { + return false; + } + + return true; + case NhAggregatedExpression aggregatedExpression: + return aggregatedExpression.AllowsNullableReturnType; + default: + return true; // a query can return null and we cannot calculate it as it is not yet executed + } + case ExpressionType.TypeIs: // an equal or in operator will be generated and those cannot return null + case ExpressionType.NewArrayInit: + return false; + case ExpressionType.Constant: + return VisitorUtil.IsNullConstant(currentExpression); + case ExpressionType.Parameter: + return !currentExpression.Type.IsValueType; + default: + return true; + } + } + } + + private bool AreEqual(MemberExpression memberExpression, MemberExpression otherMemberExpression) + { + if (memberExpression.Member != otherMemberExpression.Member || + memberExpression.Expression.NodeType != otherMemberExpression.Expression.NodeType) + { + return false; + } + + switch (memberExpression.Expression) + { + case QuerySourceReferenceExpression querySourceReferenceExpression: + if (otherMemberExpression.Expression is QuerySourceReferenceExpression otherQuerySourceReferenceExpression) + { + return querySourceReferenceExpression.ReferencedQuerySource == + otherQuerySourceReferenceExpression.ReferencedQuerySource; + } + + return false; + // Components have a nested member expression + case MemberExpression nestedMemberExpression: + if (otherMemberExpression.Expression is MemberExpression otherNestedMemberExpression) + { + return AreEqual(nestedMemberExpression, otherNestedMemberExpression); + } + + return false; + default: + return memberExpression.Expression == otherMemberExpression.Expression; + } + } + + private string TryGetEntityName(MemberExpression memberExpression) { - var hqlDot = original as HqlDot; - return hqlDot != null && hqlDot.Children.Last() is HqlIdent; + System.Type entityType; + // Try to get the actual entity type from the query source if possbile as member can be declared + // in a base type + if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) + { + entityType = querySourceReferenceExpression.Type; + } + else + { + entityType = memberExpression.Member.ReflectedType; + } + + return _parameters.SessionFactory.TryGetGuessEntityName(entityType); } protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) From c273c20eb73bd9083ce4b4d350adf543915f44f5 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 28 Jan 2019 22:09:10 +0100 Subject: [PATCH 02/11] Split subselect tests --- .../Async/Linq/NullComparisonTests.cs | 64 +++++++++++-------- .../Linq/NullComparisonTests.cs | 64 +++++++++++-------- 2 files changed, 78 insertions(+), 50 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index af1510ad68d..51b79afae37 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -33,9 +33,7 @@ public class NullComparisonTestsAsync : LinqTestCase [Test] public async Task NullInequalityWithNotNullAsync() { - IQueryable q; - - q = session.Query().Where(o => o.Input != null); + var q = session.Query().Where(o => o.Input != null); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent)); q = session.Query().Where(o => null != o.Input); @@ -134,18 +132,6 @@ public async Task NullInequalityWithNotNullAsync() q = session.Query().Where(o => (o.Input + o.Output) != o.Output); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothSame, BothDifferent)); - q = session.Query().Where(o => o.RelatedItems.Count != 1); - await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); - - q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); - await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); - - q = session.Query().Where(o => o.RelatedItems.Take(10).All(r => r.Output != null) != o.NullableBool); - await (ExpectAllAsync(q, Does.Not.Contain("or case").IgnoreCase)); - - q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); - await (ExpectAllAsync(q, Does.Contain("or (").IgnoreCase)); - q = session.Query().Where(o => o.Address.Street != o.Output); await (ExpectAsync(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull)); @@ -157,6 +143,27 @@ public async Task NullInequalityWithNotNullAsync() q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent)); + } + + [Test] + public async Task NullInequalityWithNotNullSubSelectAsync() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count != 1); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != o.NullableBool); + await (ExpectAllAsync(q, Does.Not.Contain("or case").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); + await (ExpectAllAsync(q, Does.Contain("or (").IgnoreCase)); q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != (o.NullableOutput.Length > 0)); await (ExpectAsync(q, Does.Not.Contain("or case").IgnoreCase)); @@ -165,9 +172,7 @@ public async Task NullInequalityWithNotNullAsync() [Test] public async Task NullEqualityWithNotNullAsync() { - IQueryable q; - - q = session.Query().Where(o => o.Input == null); + var q = session.Query().Where(o => o.Input == null); await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull)); q = session.Query().Where(o => null == o.Input); @@ -260,13 +265,7 @@ public async Task NullEqualityWithNotNullAsync() q = session.Query().Where(o => (o.NullableOutput + o.Output) == o.Output); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); - q = session.Query().Where(o => (o.Input + o.Output) == o.Output); - await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); - - q = session.Query().Where(o => o.RelatedItems.Count == 1); - await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); - - q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + q = session.Query().Where(o => (o.Output + o.Output) == o.Output); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); q = session.Query().Where(o => !o.Input.Equals(o.Output)); @@ -294,6 +293,21 @@ public async Task NullEqualityWithNotNullAsync() await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame)); } + [Test] + public async Task NullEqualityWithNotNullSubSelectAsync() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count == 1); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + } + [Test] public async Task NullEqualityAsync() diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index 0d9363bb5ed..f6c8ba96496 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -21,9 +21,7 @@ public class NullComparisonTests : LinqTestCase [Test] public void NullInequalityWithNotNull() { - IQueryable q; - - q = session.Query().Where(o => o.Input != null); + var q = session.Query().Where(o => o.Input != null); Expect(q, Does.Not.Contain("is null").IgnoreCase, InputSet, BothSame, BothDifferent); q = session.Query().Where(o => null != o.Input); @@ -122,18 +120,6 @@ public void NullInequalityWithNotNull() q = session.Query().Where(o => (o.Input + o.Output) != o.Output); Expect(q, Does.Not.Contain("is null").IgnoreCase, BothSame, BothDifferent); - q = session.Query().Where(o => o.RelatedItems.Count != 1); - Expect(q, Does.Not.Contain("is null").IgnoreCase); - - q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); - ExpectAll(q, Does.Contain("is null").IgnoreCase); - - q = session.Query().Where(o => o.RelatedItems.Take(10).All(r => r.Output != null) != o.NullableBool); - ExpectAll(q, Does.Not.Contain("or case").IgnoreCase); - - q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); - ExpectAll(q, Does.Contain("or (").IgnoreCase); - q = session.Query().Where(o => o.Address.Street != o.Output); Expect(q, Does.Contain("Input is null").IgnoreCase, BothDifferent, OutputSet, BothNull); @@ -145,6 +131,27 @@ public void NullInequalityWithNotNull() q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent); + } + + [Test] + public void NullInequalityWithNotNullSubSelect() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count != 1); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) != 0); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != o.NullableBool); + ExpectAll(q, Does.Not.Contain("or case").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Where(r => r.Id == 0).Sum(r => r.Input.Length) != 5); + ExpectAll(q, Does.Contain("or (").IgnoreCase); q = session.Query().Where(o => o.RelatedItems.All(r => r.Output != null) != (o.NullableOutput.Length > 0)); Expect(q, Does.Not.Contain("or case").IgnoreCase); @@ -153,9 +160,7 @@ public void NullInequalityWithNotNull() [Test] public void NullEqualityWithNotNull() { - IQueryable q; - - q = session.Query().Where(o => o.Input == null); + var q = session.Query().Where(o => o.Input == null); Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull); q = session.Query().Where(o => null == o.Input); @@ -248,13 +253,7 @@ public void NullEqualityWithNotNull() q = session.Query().Where(o => (o.NullableOutput + o.Output) == o.Output); Expect(q, Does.Not.Contain("is null").IgnoreCase); - q = session.Query().Where(o => (o.Input + o.Output) == o.Output); - Expect(q, Does.Not.Contain("is null").IgnoreCase); - - q = session.Query().Where(o => o.RelatedItems.Count == 1); - ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); - - q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + q = session.Query().Where(o => (o.Output + o.Output) == o.Output); Expect(q, Does.Not.Contain("is null").IgnoreCase); q = session.Query().Where(o => !o.Input.Equals(o.Output)); @@ -282,6 +281,21 @@ public void NullEqualityWithNotNull() Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame); } + [Test] + public void NullEqualityWithNotNullSubSelect() + { + if (!Dialect.SupportsScalarSubSelects) + { + Assert.Ignore("Dialect does not support scalar subselects"); + } + + var q = session.Query().Where(o => o.RelatedItems.Count == 1); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.RelatedItems.Max(r => r.Id) == 0); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + } + [Test] public void NullEquality() From 40b830d735d59a4ebf8f9480e7e5def8c2b0d272 Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 31 Jan 2019 21:35:50 +0100 Subject: [PATCH 03/11] Moved nullable check code into a separated class --- .../Visitors/HqlGeneratorExpressionVisitor.cs | 300 +---------------- .../Visitors/NullableExpressionDetector.cs | 309 ++++++++++++++++++ 2 files changed, 316 insertions(+), 293 deletions(-) create mode 100644 src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 8cc5c8089e8..224af655a24 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Data; using System.Dynamic; using System.Linq; @@ -8,16 +7,12 @@ using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Hql.Ast.ANTLR; -using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; -using NHibernate.Mapping.ByCode; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; -using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; -using Remotion.Linq.Clauses.ResultOperators; namespace NHibernate.Linq.Visitors { @@ -26,17 +21,7 @@ public class HqlGeneratorExpressionVisitor : IHqlExpressionVisitor private readonly HqlTreeBuilder _hqlTreeBuilder = new HqlTreeBuilder(); private readonly VisitorParameters _parameters; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - private readonly Dictionary> _equalityNotNullMembers = - new Dictionary>(); - - private static readonly HashSet NotNullOperators = new HashSet() - { - typeof(AllResultOperator), - typeof(AnyResultOperator), - typeof(ContainsResultOperator), - typeof(CountResultOperator), - typeof(LongCountResultOperator) - }; + private readonly NullableExpressionDetector _nullableExpressionDetector; public static HqlTreeNode Visit(Expression expression, VisitorParameters parameters) { @@ -47,6 +32,7 @@ public HqlGeneratorExpressionVisitor(VisitorParameters parameters) { _functionRegistry = parameters.SessionFactory.Settings.LinqToHqlGeneratorsRegistry; _parameters = parameters; + _nullableExpressionDetector = new NullableExpressionDetector(_parameters.SessionFactory, _functionRegistry); } public ISessionFactory SessionFactory { get { return _parameters.SessionFactory; } } @@ -308,94 +294,6 @@ private HqlTreeNode VisitVBStringComparisonExpression(VBStringComparisonExpressi return VisitExpression(expression.Comparison); } - private void SearchForNotNullMembersCheck(BinaryExpression expression) - { - // Check for a member not null check that has a not equals expression - // Example: o.Status != null && o.Status != "New" - // Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus) - // Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus) - if (expression.NodeType != ExpressionType.AndAlso || - expression.Right.NodeType != ExpressionType.NotEqual && - expression.Right.NodeType != ExpressionType.Equal || - expression.Left.NodeType != ExpressionType.AndAlso && - expression.Left.NodeType != ExpressionType.NotEqual) - { - return; - } - - // Skip if there are no member access expressions on the right side - var notEqualExpression = (BinaryExpression) expression.Right; - if (!IsMemberAccess(notEqualExpression.Left) && !IsMemberAccess(notEqualExpression.Right)) - { - return; - } - - var notNullMembers = new List(); - // We may have multiple conditions - // Example: o.Status != null && o.OldStatus != null - if (expression.Left.NodeType == ExpressionType.AndAlso) - { - FindAllNotNullMembers((BinaryExpression) expression.Left, notNullMembers); - } - else - { - FindNotNullMember((BinaryExpression) expression.Left, notNullMembers); - } - - if (notNullMembers.Count > 0) - { - _equalityNotNullMembers[notEqualExpression] = notNullMembers; - } - } - - private static bool IsMemberAccess(Expression expression) - { - if (expression.NodeType == ExpressionType.MemberAccess) - { - return true; - } - - // Nullable members can be wrapped in a convert expression - return expression is UnaryExpression unaryExpression && unaryExpression.Operand.NodeType == ExpressionType.MemberAccess; - } - - private static void FindAllNotNullMembers(BinaryExpression andAlsoExpression, List notNullMembers) - { - if (andAlsoExpression.Right.NodeType == ExpressionType.NotEqual) - { - FindNotNullMember((BinaryExpression) andAlsoExpression.Right, notNullMembers); - } - else if (andAlsoExpression.Right.NodeType == ExpressionType.AndAlso) - { - FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Right, notNullMembers); - } - else - { - return; - } - - if (andAlsoExpression.Left.NodeType == ExpressionType.NotEqual) - { - FindNotNullMember((BinaryExpression) andAlsoExpression.Left, notNullMembers); - } - else if (andAlsoExpression.Left.NodeType == ExpressionType.AndAlso) - { - FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Left, notNullMembers); - } - } - - private static void FindNotNullMember(BinaryExpression notEqualExpression, List notNullMembers) - { - if (notEqualExpression.Left.NodeType == ExpressionType.MemberAccess && VisitorUtil.IsNullConstant(notEqualExpression.Right)) - { - notNullMembers.Add((MemberExpression) notEqualExpression.Left); - } - else if (VisitorUtil.IsNullConstant(notEqualExpression.Left) && notEqualExpression.Right.NodeType == ExpressionType.MemberAccess) - { - notNullMembers.Add((MemberExpression) notEqualExpression.Right); - } - } - protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) { if (expression.NodeType == ExpressionType.Equal) @@ -407,7 +305,7 @@ protected HqlTreeNode VisitBinaryExpression(BinaryExpression expression) return TranslateInequalityComparison(expression); } - SearchForNotNullMembersCheck(expression); + _nullableExpressionDetector.SearchForNotNullMemberChecks(expression); var lhs = VisitExpression(expression.Left).AsExpression(); var rhs = VisitExpression(expression.Right).AsExpression(); @@ -490,8 +388,8 @@ private HqlTreeNode TranslateInequalityComparison(BinaryExpression expression) return _hqlTreeBuilder.IsNotNull(lhs); } - var lhsNullable = IsNullable(expression.Left, expression); - var rhsNullable = IsNullable(expression.Right, expression); + var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression); + var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression); var inequality = _hqlTreeBuilder.Inequality(lhs, rhs); @@ -553,8 +451,8 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) return _hqlTreeBuilder.IsNull((lhs)); } - var lhsNullable = IsNullable(expression.Left, expression); - var rhsNullable = IsNullable(expression.Right, expression); + var lhsNullable = _nullableExpressionDetector.IsNullable(expression.Left, expression); + var rhsNullable = _nullableExpressionDetector.IsNullable(expression.Right, expression); var equality = _hqlTreeBuilder.Equality(lhs, rhs); @@ -573,190 +471,6 @@ private HqlTreeNode TranslateEqualityComparison(BinaryExpression expression) _hqlTreeBuilder.IsNull(rhs2))); } - private bool IsNullable(Expression expression, BinaryExpression equalityExpression) - { - var currentExpression = expression; - while (true) - { - switch (currentExpression.NodeType) - { - case ExpressionType.Convert: - case ExpressionType.ConvertChecked: - case ExpressionType.TypeAs: - var unaryExpression = (UnaryExpression) currentExpression; - return IsNullable(unaryExpression.Operand, equalityExpression); // a cast will not return null if the operand is not null - case ExpressionType.Not: - case ExpressionType.And: - case ExpressionType.Or: - case ExpressionType.ExclusiveOr: - case ExpressionType.LeftShift: - case ExpressionType.RightShift: - case ExpressionType.AndAlso: - case ExpressionType.OrElse: - case ExpressionType.Equal: - case ExpressionType.NotEqual: - case ExpressionType.GreaterThanOrEqual: - case ExpressionType.GreaterThan: - case ExpressionType.LessThan: - case ExpressionType.LessThanOrEqual: - return false; - case ExpressionType.Add: - case ExpressionType.AddChecked: - case ExpressionType.Divide: - case ExpressionType.Modulo: - case ExpressionType.Multiply: - case ExpressionType.MultiplyChecked: - case ExpressionType.Power: - case ExpressionType.Subtract: - case ExpressionType.SubtractChecked: - var binaryExpression = (BinaryExpression) currentExpression; - return IsNullable(binaryExpression.Left, equalityExpression) || IsNullable(binaryExpression.Right, equalityExpression); - case ExpressionType.ArrayIndex: - return true; // for indexed lists we cannot determine whether the item will be null or not - case ExpressionType.Coalesce: - return IsNullable(((BinaryExpression) currentExpression).Right, equalityExpression); - case ExpressionType.Conditional: - var conditionalExpression = (ConditionalExpression) currentExpression; - return IsNullable(conditionalExpression.IfTrue, equalityExpression) || - IsNullable(conditionalExpression.IfFalse, equalityExpression); - case ExpressionType.Call: - var methodInfo = ((MethodCallExpression) currentExpression).Method; - return !_functionRegistry.TryGetGenerator(methodInfo, out var method) || method.AllowsNullableReturnType(methodInfo); - case ExpressionType.MemberAccess: - var memberExpression = (MemberExpression) currentExpression; - - if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _)) - { - // We have to skip the property as it will be converted to a function that can return null - // if the argument is null - currentExpression = memberExpression.Expression; - continue; - } - - var memberType = ReflectHelper.GetPropertyOrFieldType(memberExpression.Member); - if (memberType?.IsValueType == true && !memberType.IsNullable()) - { - currentExpression = memberExpression.Expression; - continue; - } - - // Check if there was a not null check prior the equality expression - if (( - equalityExpression.NodeType == ExpressionType.NotEqual || - equalityExpression.NodeType == ExpressionType.Equal - ) && - _equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) && - notNullMembers.Any(o => AreEqual(o, memberExpression))) - { - return false; - } - - // We have to check the member mapping to determine if is nullable - var entityName = TryGetEntityName(memberExpression); - if (entityName == null) - { - return true; // not mapped - } - - var persister = _parameters.SessionFactory.GetEntityPersister(entityName); - var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name); - if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value]) - { - return true; // not mapped or nullable - } - - currentExpression = memberExpression.Expression; - continue; - case ExpressionType.Extension: - switch (currentExpression) - { - case QuerySourceReferenceExpression querySourceReferenceExpression: - switch (querySourceReferenceExpression.ReferencedQuerySource) - { - case MainFromClause _: - return false; // we reached to the root expression, there were no nullable expressions - case NhJoinClause joinClause: - return IsNullable(joinClause.FromExpression, equalityExpression); - default: - return true; // unknown query source - } - case SubQueryExpression subQuery: - if (subQuery.QueryModel.SelectClause.Selector is NhAggregatedExpression subQueryAggregatedExpression) - { - return subQueryAggregatedExpression.AllowsNullableReturnType; - } - else if (subQuery.QueryModel.ResultOperators.Any(o => NotNullOperators.Contains(o.GetType()))) - { - return false; - } - - return true; - case NhAggregatedExpression aggregatedExpression: - return aggregatedExpression.AllowsNullableReturnType; - default: - return true; // a query can return null and we cannot calculate it as it is not yet executed - } - case ExpressionType.TypeIs: // an equal or in operator will be generated and those cannot return null - case ExpressionType.NewArrayInit: - return false; - case ExpressionType.Constant: - return VisitorUtil.IsNullConstant(currentExpression); - case ExpressionType.Parameter: - return !currentExpression.Type.IsValueType; - default: - return true; - } - } - } - - private bool AreEqual(MemberExpression memberExpression, MemberExpression otherMemberExpression) - { - if (memberExpression.Member != otherMemberExpression.Member || - memberExpression.Expression.NodeType != otherMemberExpression.Expression.NodeType) - { - return false; - } - - switch (memberExpression.Expression) - { - case QuerySourceReferenceExpression querySourceReferenceExpression: - if (otherMemberExpression.Expression is QuerySourceReferenceExpression otherQuerySourceReferenceExpression) - { - return querySourceReferenceExpression.ReferencedQuerySource == - otherQuerySourceReferenceExpression.ReferencedQuerySource; - } - - return false; - // Components have a nested member expression - case MemberExpression nestedMemberExpression: - if (otherMemberExpression.Expression is MemberExpression otherNestedMemberExpression) - { - return AreEqual(nestedMemberExpression, otherNestedMemberExpression); - } - - return false; - default: - return memberExpression.Expression == otherMemberExpression.Expression; - } - } - - private string TryGetEntityName(MemberExpression memberExpression) - { - System.Type entityType; - // Try to get the actual entity type from the query source if possbile as member can be declared - // in a base type - if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) - { - entityType = querySourceReferenceExpression.Type; - } - else - { - entityType = memberExpression.Member.ReflectedType; - } - - return _parameters.SessionFactory.TryGetGuessEntityName(entityType); - } - protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression) { switch (expression.NodeType) diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs new file mode 100644 index 00000000000..bc4f1fa57b0 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -0,0 +1,309 @@ +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using NHibernate.Engine; +using NHibernate.Linq.Clauses; +using NHibernate.Linq.Expressions; +using NHibernate.Linq.Functions; +using NHibernate.Util; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; + +namespace NHibernate.Linq.Visitors +{ + internal class NullableExpressionDetector + { + private static readonly HashSet NotNullOperators = new HashSet + { + typeof(AllResultOperator), + typeof(AnyResultOperator), + typeof(ContainsResultOperator), + typeof(CountResultOperator), + typeof(LongCountResultOperator) + }; + + private readonly Dictionary> _equalityNotNullMembers = + new Dictionary>(); + private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; + + public NullableExpressionDetector(ISessionFactoryImplementor sessionFactory, ILinqToHqlGeneratorsRegistry functionRegistry) + { + _sessionFactory = sessionFactory; + _functionRegistry = functionRegistry; + } + + public void SearchForNotNullMemberChecks(BinaryExpression expression) + { + // Check for a member not null check that has a not equals expression + // Example: o.Status != null && o.Status != "New" + // Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus) + // Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus) + if (expression.NodeType != ExpressionType.AndAlso || + expression.Right.NodeType != ExpressionType.NotEqual && + expression.Right.NodeType != ExpressionType.Equal || + expression.Left.NodeType != ExpressionType.AndAlso && + expression.Left.NodeType != ExpressionType.NotEqual) + { + return; + } + + // Skip if there are no member access expressions on the right side + var notEqualExpression = (BinaryExpression) expression.Right; + if (!IsMemberAccess(notEqualExpression.Left) && !IsMemberAccess(notEqualExpression.Right)) + { + return; + } + + var notNullMembers = new List(); + // We may have multiple conditions + // Example: o.Status != null && o.OldStatus != null + if (expression.Left.NodeType == ExpressionType.AndAlso) + { + FindAllNotNullMembers((BinaryExpression) expression.Left, notNullMembers); + } + else + { + FindNotNullMember((BinaryExpression) expression.Left, notNullMembers); + } + + if (notNullMembers.Count > 0) + { + _equalityNotNullMembers[notEqualExpression] = notNullMembers; + } + } + + public bool IsNullable(Expression expression, BinaryExpression equalityExpression) + { + switch (expression.NodeType) + { + case ExpressionType.Convert: + case ExpressionType.ConvertChecked: + case ExpressionType.TypeAs: + return IsNullable(((UnaryExpression) expression).Operand, equalityExpression); // a cast will not return null if the operand is not null + case ExpressionType.Not: + case ExpressionType.And: + case ExpressionType.Or: + case ExpressionType.ExclusiveOr: + case ExpressionType.LeftShift: + case ExpressionType.RightShift: + case ExpressionType.AndAlso: + case ExpressionType.OrElse: + case ExpressionType.Equal: + case ExpressionType.NotEqual: + case ExpressionType.GreaterThanOrEqual: + case ExpressionType.GreaterThan: + case ExpressionType.LessThan: + case ExpressionType.LessThanOrEqual: + return false; + case ExpressionType.Add: + case ExpressionType.AddChecked: + case ExpressionType.Divide: + case ExpressionType.Modulo: + case ExpressionType.Multiply: + case ExpressionType.MultiplyChecked: + case ExpressionType.Power: + case ExpressionType.Subtract: + case ExpressionType.SubtractChecked: + var binaryExpression = (BinaryExpression) expression; + return IsNullable(binaryExpression.Left, equalityExpression) || IsNullable(binaryExpression.Right, equalityExpression); + case ExpressionType.ArrayIndex: + return true; // for indexed lists we cannot determine whether the item will be null or not + case ExpressionType.Coalesce: + return IsNullable(((BinaryExpression) expression).Right, equalityExpression); + case ExpressionType.Conditional: + var conditionalExpression = (ConditionalExpression) expression; + return IsNullable(conditionalExpression.IfTrue, equalityExpression) || + IsNullable(conditionalExpression.IfFalse, equalityExpression); + case ExpressionType.Call: + var methodInfo = ((MethodCallExpression) expression).Method; + return !_functionRegistry.TryGetGenerator(methodInfo, out var method) || method.AllowsNullableReturnType(methodInfo); + case ExpressionType.MemberAccess: + return IsNullable((MemberExpression) expression, equalityExpression); + case ExpressionType.Extension: + return IsNullableExtension(expression, equalityExpression); + case ExpressionType.TypeIs: // an equal or in operator will be generated and those cannot return null + case ExpressionType.NewArrayInit: + return false; + case ExpressionType.Constant: + return VisitorUtil.IsNullConstant(expression); + case ExpressionType.Parameter: + return !expression.Type.IsValueType; + default: + return true; + } + } + + private bool IsNullable(MemberExpression memberExpression, BinaryExpression equalityExpression) + { + if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _)) + { + // We have to skip the property as it will be converted to a function that can return null + // if the argument is null + return IsNullable(memberExpression.Expression, equalityExpression); + } + + var memberType = memberExpression.Member.GetPropertyOrFieldType(); + if (memberType?.IsValueType == true && !memberType.IsNullable()) + { + return IsNullable(memberExpression.Expression, equalityExpression); + } + + // Check if there was a not null check prior the equality expression + if (( + equalityExpression.NodeType == ExpressionType.NotEqual || + equalityExpression.NodeType == ExpressionType.Equal + ) && + _equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) && + notNullMembers.Any(o => AreEqual(o, memberExpression))) + { + return false; + } + + // We have to check the member mapping to determine if is nullable + var entityName = TryGetEntityName(memberExpression); + if (entityName == null) + { + return true; // not mapped + } + + var persister = _sessionFactory.GetEntityPersister(entityName); + var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name); + if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value]) + { + return true; // not mapped or nullable + } + + return IsNullable(memberExpression.Expression, equalityExpression); + } + + private bool IsNullableExtension(Expression extensionExpression, BinaryExpression equalityExpression) + { + switch (extensionExpression) + { + case QuerySourceReferenceExpression querySourceReferenceExpression: + switch (querySourceReferenceExpression.ReferencedQuerySource) + { + case MainFromClause _: + return false; // we reached to the root expression, there were no nullable expressions + case NhJoinClause joinClause: + return IsNullable(joinClause.FromExpression, equalityExpression); + default: + return true; // unknown query source + } + case SubQueryExpression subQueryExpression: + if (subQueryExpression.QueryModel.SelectClause.Selector is NhAggregatedExpression subQueryAggregatedExpression) + { + return subQueryAggregatedExpression.AllowsNullableReturnType; + } + else if (subQueryExpression.QueryModel.ResultOperators.Any(o => NotNullOperators.Contains(o.GetType()))) + { + return false; + } + + return true; + case NhAggregatedExpression aggregatedExpression: + return aggregatedExpression.AllowsNullableReturnType; + default: + return true; // a query can return null and we cannot calculate it as it is not yet executed + } + } + + private string TryGetEntityName(MemberExpression memberExpression) + { + System.Type entityType; + // Try to get the actual entity type from the query source if possbile as member can be declared + // in a base type + if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) + { + entityType = querySourceReferenceExpression.Type; + } + else + { + entityType = memberExpression.Member.ReflectedType; + } + + return _sessionFactory.TryGetGuessEntityName(entityType); + } + + private static bool IsMemberAccess(Expression expression) + { + if (expression.NodeType == ExpressionType.MemberAccess) + { + return true; + } + + // Nullable members can be wrapped in a convert expression + return expression is UnaryExpression unaryExpression && unaryExpression.Operand.NodeType == ExpressionType.MemberAccess; + } + + private static void FindAllNotNullMembers(BinaryExpression andAlsoExpression, List notNullMembers) + { + if (andAlsoExpression.Right.NodeType == ExpressionType.NotEqual) + { + FindNotNullMember((BinaryExpression) andAlsoExpression.Right, notNullMembers); + } + else if (andAlsoExpression.Right.NodeType == ExpressionType.AndAlso) + { + FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Right, notNullMembers); + } + else + { + return; + } + + if (andAlsoExpression.Left.NodeType == ExpressionType.NotEqual) + { + FindNotNullMember((BinaryExpression) andAlsoExpression.Left, notNullMembers); + } + else if (andAlsoExpression.Left.NodeType == ExpressionType.AndAlso) + { + FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Left, notNullMembers); + } + } + + private static void FindNotNullMember(BinaryExpression notEqualExpression, List notNullMembers) + { + if (notEqualExpression.Left.NodeType == ExpressionType.MemberAccess && VisitorUtil.IsNullConstant(notEqualExpression.Right)) + { + notNullMembers.Add((MemberExpression) notEqualExpression.Left); + } + else if (VisitorUtil.IsNullConstant(notEqualExpression.Left) && notEqualExpression.Right.NodeType == ExpressionType.MemberAccess) + { + notNullMembers.Add((MemberExpression) notEqualExpression.Right); + } + } + + private static bool AreEqual(MemberExpression memberExpression, MemberExpression otherMemberExpression) + { + if (memberExpression.Member != otherMemberExpression.Member || + memberExpression.Expression.NodeType != otherMemberExpression.Expression.NodeType) + { + return false; + } + + switch (memberExpression.Expression) + { + case QuerySourceReferenceExpression querySourceReferenceExpression: + if (otherMemberExpression.Expression is QuerySourceReferenceExpression otherQuerySourceReferenceExpression) + { + return querySourceReferenceExpression.ReferencedQuerySource == + otherQuerySourceReferenceExpression.ReferencedQuerySource; + } + + return false; + // Components have a nested member expression + case MemberExpression nestedMemberExpression: + if (otherMemberExpression.Expression is MemberExpression otherNestedMemberExpression) + { + return AreEqual(nestedMemberExpression, otherNestedMemberExpression); + } + + return false; + default: + return memberExpression.Expression == otherMemberExpression.Expression; + } + } + } +} From f5a5acb2ea815f43396ac472be1e1ae5ac9d9c69 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 8 Apr 2019 21:49:10 +0200 Subject: [PATCH 04/11] Fix edge case scenarios --- .../Async/Linq/NullComparisonTests.cs | 70 ++++++++++++++++++- .../Linq/NullComparisonTests.cs | 70 ++++++++++++++++++- .../Persister/Entity/IEntityPersister.cs | 1 + .../Functions/BaseHqlGeneratorForMethod.cs | 5 -- .../Visitors/NullableExpressionDetector.cs | 31 +++----- 5 files changed, 147 insertions(+), 30 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 51b79afae37..010035edc4f 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -143,6 +143,18 @@ public async Task NullInequalityWithNotNullAsync() q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); await (ExpectAsync(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId != null), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null != o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -173,10 +185,10 @@ public async Task NullInequalityWithNotNullSubSelectAsync() public async Task NullEqualityWithNotNullAsync() { var q = session.Query().Where(o => o.Input == null); - await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull)); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull)); q = session.Query().Where(o => null == o.Input); - await (ExpectAsync(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull)); + await (ExpectAsync(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull)); q = session.Query().Where(o => o.InputNullability == AnotherEntityNullability.True); await (ExpectAsync(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet)); @@ -291,6 +303,15 @@ public async Task NullEqualityWithNotNullAsync() q = session.Query().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput); await (ExpectAsync(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame)); + + await (ExpectAsync(session.Query().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -377,6 +398,24 @@ public async Task NullEqualityAsync() // Columns against columns q = from x in session.Query() where x.Input == x.Output select x; await (ExpectAsync(q, BothSame, BothNull)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName == null), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null == o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null == o.Component.Property1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 == null), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" == o.Component.Property1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 == "test"), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null == o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == null), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -435,6 +474,24 @@ public async Task NullInequalityAsync() // Columns against columns q = from x in session.Query() where x.Input != x.Output select x; await (ExpectAsync(q, BothDifferent, InputSet, OutputSet)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName != null), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => null != o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.ContactName != "test"), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null != o.Component.Property1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 != null), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" != o.Component.Property1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.Property1 != "test"), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => null != o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != null), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase)); } [Test] @@ -633,6 +690,15 @@ private async Task ExpectAsync(IQueryable q, IResolveCons return (await (q.ToListAsync(cancellationToken))).OrderBy(Key).ToList(); } + private static async Task ExpectAsync(IQueryable query, IResolveConstraint sqlConstraint, CancellationToken cancellationToken = default(CancellationToken)) + { + using (var sqlLog = new SqlLogSpy()) + { + var list = await (query.ToListAsync(cancellationToken)); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + private static string Key(AnotherEntityRequired e) { return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index f6c8ba96496..386c5361d8c 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -131,6 +131,18 @@ public void NullInequalityWithNotNull() q = session.Query().Where(o => o.Address.Street != null && o.Address.Street != o.NullableOutput); Expect(q, Does.Contain("Output is null").IgnoreCase, InputSet, BothDifferent); + + Expect(session.Query().Where(o => o.CustomerId != null), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null != o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.CustomerId != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -161,10 +173,10 @@ public void NullInequalityWithNotNullSubSelect() public void NullEqualityWithNotNull() { var q = session.Query().Where(o => o.Input == null); - Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull); + Expect(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull); q = session.Query().Where(o => null == o.Input); - Expect(q, Does.Not.Contain("or is null").IgnoreCase, OutputSet, BothNull); + Expect(q, Does.Contain("is null").IgnoreCase, OutputSet, BothNull); q = session.Query().Where(o => o.InputNullability == AnotherEntityNullability.True); Expect(q, Does.Not.Contain("end is null").IgnoreCase, BothNull, OutputSet); @@ -279,6 +291,15 @@ public void NullEqualityWithNotNull() q = session.Query().Where(o => o.Address.Street != null && o.Address.Street == o.NullableOutput); Expect(q, Does.Not.Contain("Output is null").IgnoreCase, BothSame); + + Expect(session.Query().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -365,6 +386,24 @@ public void NullEquality() // Columns against columns q = from x in session.Query() where x.Input == x.Output select x; Expect(q, BothSame, BothNull); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName == null), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null == o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null == o.Component.Property1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 == null), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" == o.Component.Property1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 == "test"), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null == o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == null), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -423,6 +462,24 @@ public void NullInequality() // Columns against columns q = from x in session.Query() where x.Input != x.Output select x; Expect(q, BothDifferent, InputSet, OutputSet); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName != null), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => null != o.Order.Customer.ContactName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.Order.Customer.ContactName != "test"), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.Order.Customer.ContactName), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null != o.Component.Property1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 != null), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" != o.Component.Property1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.Property1 != "test"), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => null != o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != null), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase); } [Test] @@ -621,6 +678,15 @@ private IList GetResults(IQueryable(IQueryable query, IResolveConstraint sqlConstraint) + { + using (var sqlLog = new SqlLogSpy()) + { + var list = query.ToList(); + Assert.That(sqlLog.GetWholeLog(), sqlConstraint); + } + } + private static string Key(AnotherEntityRequired e) { return "Input=" + (e.Input ?? "NULL") + ", Output=" + (e.Output ?? "NULL"); diff --git a/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs b/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs index db1e4a16c1c..b24810f3998 100644 --- a/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs +++ b/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using NHibernate.Intercept; using NHibernate.Util; +using System.Linq; namespace NHibernate.Persister.Entity { diff --git a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs index d65acba8b87..53955d136a6 100644 --- a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs @@ -14,10 +14,5 @@ public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGe public abstract HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); public virtual bool AllowsNullableReturnType(MethodInfo method) => true; - - bool IHqlGeneratorForMethodExtended.AllowsNullableReturnType(MethodInfo method) - { - return AllowsNullableReturnType(method); - } } } diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index bc4f1fa57b0..298c5efb737 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -5,6 +5,7 @@ using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; +using NHibernate.Persister.Entity; using NHibernate.Util; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -162,17 +163,22 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa } // We have to check the member mapping to determine if is nullable - var entityName = TryGetEntityName(memberExpression); + var entityName = ExpressionsHelper.TryGetEntityName(_sessionFactory, memberExpression, out var memberPath); if (entityName == null) { - return true; // not mapped + return true; // Not mapped } var persister = _sessionFactory.GetEntityPersister(entityName); - var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberExpression.Member.Name); + if (persister.IsIdentifierMember(memberPath)) + { + return false; // Identifier is always not null + } + + var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath); if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value]) { - return true; // not mapped or nullable + return true; // Not mapped or nullable } return IsNullable(memberExpression.Expression, equalityExpression); @@ -210,23 +216,6 @@ private bool IsNullableExtension(Expression extensionExpression, BinaryExpressio } } - private string TryGetEntityName(MemberExpression memberExpression) - { - System.Type entityType; - // Try to get the actual entity type from the query source if possbile as member can be declared - // in a base type - if (memberExpression.Expression is QuerySourceReferenceExpression querySourceReferenceExpression) - { - entityType = querySourceReferenceExpression.Type; - } - else - { - entityType = memberExpression.Member.ReflectedType; - } - - return _sessionFactory.TryGetGuessEntityName(entityType); - } - private static bool IsMemberAccess(Expression expression) { if (expression.NodeType == ExpressionType.MemberAccess) From c01ac7b63899610e312dcf94e5223ba32874abfc Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 9 Apr 2019 18:40:39 +0200 Subject: [PATCH 05/11] Fix TryGetEntityName for interface mapped members --- .../Northwind/Entities/User.cs | 6 +++ .../Northwind/Mappings/User.hbm.xml | 10 ++++- .../Async/Linq/NullComparisonTests.cs | 39 ++++++++++++++++++ ...Sql2008DialectLinqReadonlyCreateScript.sql | Bin 1868594 -> 1868954 bytes ...Sql2012DialectLinqReadonlyCreateScript.sql | 8 ++-- ...reSQL83DialectLinqReadonlyCreateScript.sql | Bin 1437774 -> 1438124 bytes .../Linq/NorthwindDbCreator.cs | 7 +++- .../Linq/NullComparisonTests.cs | 39 ++++++++++++++++++ .../Visitors/NullableExpressionDetector.cs | 2 +- 9 files changed, 105 insertions(+), 6 deletions(-) diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index c3f220ffda5..c23e667be9b 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -26,6 +26,8 @@ public interface IUser Role Role { get; set; } EnumStoredAsString Enum1 { get; set; } EnumStoredAsInt32 Enum2 { get; set; } + IUser CreatedBy { get; set; } + IUser ModifiedBy { get; set; } } public class User : IUser, IEntity @@ -50,6 +52,10 @@ public class User : IUser, IEntity public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual IUser CreatedBy { get; set; } + + public virtual IUser ModifiedBy { get; set; } + public virtual int NotMapped { get; set; } public virtual Role NotMappedRole { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index f59dc2956c6..2764cb70898 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -7,11 +7,19 @@ - + + + + + + + + + diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 010035edc4f..70c49d4c0be 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -155,6 +155,12 @@ public async Task NullInequalityWithNotNullAsync() await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name != "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Id != 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -306,12 +312,21 @@ public async Task NullEqualityWithNotNullAsync() await (ExpectAsync(session.Query().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -416,6 +431,18 @@ public async Task NullEqualityAsync() await (ExpectAsync(session.Query().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 == o.ModifiedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase)); } [Test] @@ -492,6 +519,18 @@ public async Task NullInequalityAsync() await (ExpectAsync(session.Query().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); await (ExpectAsync(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name != "test"), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id != 5), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 != o.ModifiedBy.CreatedBy.Id), Does.Contain("is null").IgnoreCase)); + + await (ExpectAsync(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase)); + await (ExpectAsync(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase)); } [Test] diff --git a/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql b/src/NHibernate.Test/DbScripts/MsSql2008DialectLinqReadonlyCreateScript.sql index b9dbec1e05dff6bb9bc9ec62ff68dd5c95b924ed..12235095e0ddb6ef8e6d4eea097d578a206e2565 100644 GIT binary patch delta 377 zcmdngRy3=#sG)_ig{g(Pg{6hHg{_6Xg``#CMvd8Ahe3fMn!%X?%^U_#pypTxh%7dv7S>5mu1n#XzQ9E0 z3M)vN;q-~-GQX$Gn8>7&Yv}YuB|g>Z&&*_YOs{*z%e8%)xl9QoxrR*!x^Q}dF(21< Q4NDm|Mu>@ORx&)E02+f_UH||9 delta 172 zcmbQ$S+uFGsG)_ig{g(Pg{6hHg{_6Xg`9cHQ1U5%81u#yZperM@J-}FI731U<0lDd6rZQJ1 zA5l`;9%m+F#5ny)KA+h13HCB9lO0l7roSm)uP*2K`IO< zGj7%0KCentf{`4<{^yD+PQO(vI%oQaP!X=lY+E(RHfu8f4$0{s)-rNTKUFV!VtSm5 S2-oy)5~6n7Pc(?O().Where(o => o.Order.Customer.CompanyName != "test"), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => "test" != o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name != "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Id != 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 != o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -294,12 +300,21 @@ public void NullEqualityWithNotNull() Expect(session.Query().Where(o => o.CustomerId == null), Does.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => null == o.CustomerId), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => "test" == o.CustomerId), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Order.Customer.CustomerId == "test"), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => "test" == o.Order.Customer.CustomerId), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => o.Order.Customer.CompanyName == "test"), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => "test" == o.Order.Customer.CompanyName), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 == o.CreatedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -404,6 +419,18 @@ public void NullEquality() Expect(session.Query().Where(o => "test" == o.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 == "test"), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" == o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 == o.ModifiedBy.CreatedBy.Id), Does.Not.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id == 5), Does.Not.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 == o.CreatedBy.ModifiedBy.Id), Does.Not.Contain("is null").IgnoreCase); } [Test] @@ -480,6 +507,18 @@ public void NullInequality() Expect(session.Query().Where(o => "test" != o.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); Expect(session.Query().Where(o => o.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.CreatedBy.Name != "test"), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CreatedBy.ModifiedBy.CreatedBy.Name), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1 != "test"), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => "test" != o.CreatedBy.CreatedBy.Component.OtherComponent.OtherProperty1), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.ModifiedBy.CreatedBy.Id != 5), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 != o.ModifiedBy.CreatedBy.Id), Does.Contain("is null").IgnoreCase); + + Expect(session.Query().Where(o => o.CreatedBy.ModifiedBy.Id != 5), Does.Contain("is null").IgnoreCase); + Expect(session.Query().Where(o => 5 != o.CreatedBy.ModifiedBy.Id), Does.Contain("is null").IgnoreCase); } [Test] diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index 298c5efb737..140aa01002e 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -170,7 +170,7 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa } var persister = _sessionFactory.GetEntityPersister(entityName); - if (persister.IsIdentifierMember(memberPath)) + if (persister.EntityMetamodel.GetIdentifierPropertyType(memberPath) != null) { return false; // Identifier is always not null } From f089c449aa75f686b42de23f19c7a3cf4656a588 Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 4 Oct 2019 23:30:45 +0200 Subject: [PATCH 06/11] Fix TryGetEntityName for custom entity names --- src/NHibernate/Async/Persister/Entity/IEntityPersister.cs | 1 - src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs b/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs index b24810f3998..db1e4a16c1c 100644 --- a/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs +++ b/src/NHibernate/Async/Persister/Entity/IEntityPersister.cs @@ -20,7 +20,6 @@ using System.Collections.Generic; using NHibernate.Intercept; using NHibernate.Util; -using System.Linq; namespace NHibernate.Persister.Entity { diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index 140aa01002e..d75d16fd1d9 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -163,8 +163,8 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa } // We have to check the member mapping to determine if is nullable - var entityName = ExpressionsHelper.TryGetEntityName(_sessionFactory, memberExpression, out var memberPath); - if (entityName == null) + var entityName = ExpressionsHelper.TryGetEntityName(_sessionFactory, memberExpression, out var memberPath, out _); + if (entityName == null || memberPath == null) { return true; // Not mapped } From a0863a167dac3907305f2ac8b7aba6b5507d8cc7 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 5 Oct 2019 02:45:25 +0200 Subject: [PATCH 07/11] Fix some CodeFactor issues --- src/NHibernate.Test/Linq/NullComparisonTests.cs | 1 - .../Linq/Visitors/NullableExpressionDetector.cs | 15 ++++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index 910bd7cef2b..df2e8d3b06f 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -332,7 +332,6 @@ public void NullEqualityWithNotNullSubSelect() Expect(q, Does.Not.Contain("is null").IgnoreCase); } - [Test] public void NullEquality() { diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index d75d16fd1d9..c59bf131d44 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -41,11 +41,16 @@ public void SearchForNotNullMemberChecks(BinaryExpression expression) // Example: o.Status != null && o.Status != "New" // Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus) // Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus) - if (expression.NodeType != ExpressionType.AndAlso || - expression.Right.NodeType != ExpressionType.NotEqual && - expression.Right.NodeType != ExpressionType.Equal || - expression.Left.NodeType != ExpressionType.AndAlso && - expression.Left.NodeType != ExpressionType.NotEqual) + if ( + expression.NodeType != ExpressionType.AndAlso || + ( + expression.Right.NodeType != ExpressionType.NotEqual && + expression.Right.NodeType != ExpressionType.Equal + ) || + ( + expression.Left.NodeType != ExpressionType.AndAlso && + expression.Left.NodeType != ExpressionType.NotEqual + )) { return; } From 96e29c2b4369e7db1f5087f4ebaf24fa1147f334 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 5 Oct 2019 02:49:02 +0200 Subject: [PATCH 08/11] Regenerate async --- src/NHibernate.Test/Async/Linq/NullComparisonTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index 70c49d4c0be..b3f07465c70 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -344,7 +344,6 @@ public async Task NullEqualityWithNotNullSubSelectAsync() await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); } - [Test] public async Task NullEqualityAsync() { From 29d205242b550a231fb7ca9ba3740ece0c79026d Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 21 Feb 2020 10:12:44 +0100 Subject: [PATCH 09/11] Update code to use TryGetMappedNullability method --- .../Visitors/NullableExpressionDetector.cs | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index c59bf131d44..674cc78daec 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -5,7 +5,6 @@ using NHibernate.Linq.Clauses; using NHibernate.Linq.Expressions; using NHibernate.Linq.Functions; -using NHibernate.Persister.Entity; using NHibernate.Util; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -167,23 +166,9 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa return false; } - // We have to check the member mapping to determine if is nullable - var entityName = ExpressionsHelper.TryGetEntityName(_sessionFactory, memberExpression, out var memberPath, out _); - if (entityName == null || memberPath == null) + if (!ExpressionsHelper.TryGetMappedNullability(_sessionFactory, memberExpression, out var nullable) || nullable) { - return true; // Not mapped - } - - var persister = _sessionFactory.GetEntityPersister(entityName); - if (persister.EntityMetamodel.GetIdentifierPropertyType(memberPath) != null) - { - return false; // Identifier is always not null - } - - var index = persister.EntityMetamodel.GetPropertyIndexOrNull(memberPath); - if (!index.HasValue || persister.EntityMetamodel.PropertyNullability[index.Value]) - { - return true; // Not mapped or nullable + return true; // The expression contains one or many unsupported nodes or the member is nullable } return IsNullable(memberExpression.Expression, equalityExpression); From 9d0817033689a3354641909f52bfb391973c8031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Sun, 23 Feb 2020 19:45:49 +0100 Subject: [PATCH 10/11] Adjust some details --- src/NHibernate.Test/Linq/NorthwindDbCreator.cs | 2 +- src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/Linq/NorthwindDbCreator.cs b/src/NHibernate.Test/Linq/NorthwindDbCreator.cs index bfbb1505302..fa484be3a92 100644 --- a/src/NHibernate.Test/Linq/NorthwindDbCreator.cs +++ b/src/NHibernate.Test/Linq/NorthwindDbCreator.cs @@ -72,7 +72,7 @@ public static void CreateMiscTestData(ISession session) foreach (var user in users) { - user.CreatedBy = user; + user.CreatedBy = users[0]; } var timesheets = new[] diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index 674cc78daec..1e4d609449f 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -86,7 +86,9 @@ public bool IsNullable(Expression expression, BinaryExpression equalityExpressio case ExpressionType.Convert: case ExpressionType.ConvertChecked: case ExpressionType.TypeAs: - return IsNullable(((UnaryExpression) expression).Operand, equalityExpression); // a cast will not return null if the operand is not null + // a cast will not return null if the operand is not null (as long as TypeAs is not translated to + // try_convert in SQL). + return IsNullable(((UnaryExpression) expression).Operand, equalityExpression); case ExpressionType.Not: case ExpressionType.And: case ExpressionType.Or: From d25097161a36a31be58a779eab6c0048c96c4ab7 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 24 Feb 2020 20:07:29 +0100 Subject: [PATCH 11/11] Code review changes --- .../Async/Linq/NullComparisonTests.cs | 30 +++++ .../Linq/NullComparisonTests.cs | 30 +++++ .../Visitors/NullableExpressionDetector.cs | 116 +++++++++--------- 3 files changed, 121 insertions(+), 55 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs index b3f07465c70..6a5c75091c7 100644 --- a/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Async/Linq/NullComparisonTests.cs @@ -102,18 +102,39 @@ public async Task NullInequalityWithNotNullAsync() q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != 0); await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 && o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value != 0); await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 || o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableOutput != null && o.NullableOutput != "test"); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + q = session.Query().Where(o => o.NullableOutput != "test" && o.NullableOutput != null); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput != "test"); await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableOutput != "test" || o.NullableOutput != null); + await (ExpectAllAsync(q, Does.Contain("is null").IgnoreCase)); + + q = session.Query().Where(o => o.NullableOutput != "test" && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != null)); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + + q = session.Query().Where(o => o.NullableOutput != null && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != "test")); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet)); + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequiredId.Value); await (ExpectAsync(q, Does.Contain("or case").IgnoreCase)); @@ -253,12 +274,21 @@ public async Task NullEqualityWithNotNullAsync() q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == 0); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 && o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value == 0); await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 || o.NullableAnotherEntityRequiredId.HasValue); + await (ExpectAllAsync(q, Does.Not.Contain("is null").IgnoreCase)); + q = session.Query().Where(o => o.NullableOutput == "test"); await (ExpectAsync(q, Does.Not.Contain("is null").IgnoreCase)); diff --git a/src/NHibernate.Test/Linq/NullComparisonTests.cs b/src/NHibernate.Test/Linq/NullComparisonTests.cs index df2e8d3b06f..0ed569813f3 100644 --- a/src/NHibernate.Test/Linq/NullComparisonTests.cs +++ b/src/NHibernate.Test/Linq/NullComparisonTests.cs @@ -90,18 +90,39 @@ public void NullInequalityWithNotNull() q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); Expect(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value != 0); ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 && o.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value != 0); ExpectAll(q, Does.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value != 0 || o.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableOutput != null && o.NullableOutput != "test"); Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + q = session.Query().Where(o => o.NullableOutput != "test" && o.NullableOutput != null); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + q = session.Query().Where(o => o.NullableOutput != null || o.NullableOutput != "test"); ExpectAll(q, Does.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableOutput != "test" || o.NullableOutput != null); + ExpectAll(q, Does.Contain("is null").IgnoreCase); + + q = session.Query().Where(o => o.NullableOutput != "test" && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != null)); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + + q = session.Query().Where(o => o.NullableOutput != null && (o.NullableAnotherEntityRequiredId > 0 && o.NullableOutput != "test")); + Expect(q, Does.Not.Contain("is null").IgnoreCase, BothDifferent, BothSame, OutputSet); + q = session.Query().Where(o => o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value != o.NullableAnotherEntityRequiredId.Value); Expect(q, Does.Contain("or case").IgnoreCase); @@ -241,12 +262,21 @@ public void NullEqualityWithNotNull() q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value); ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.Value && o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequired.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue && o.NullableAnotherEntityRequiredId.Value == 0); Expect(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 && o.NullableAnotherEntityRequiredId.HasValue); + Expect(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.HasValue || o.NullableAnotherEntityRequiredId.Value == 0); ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableAnotherEntityRequiredId.Value == 0 || o.NullableAnotherEntityRequiredId.HasValue); + ExpectAll(q, Does.Not.Contain("is null").IgnoreCase); + q = session.Query().Where(o => o.NullableOutput == "test"); Expect(q, Does.Not.Contain("is null").IgnoreCase); diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index 1e4d609449f..d9e2d6c06f5 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -40,43 +40,25 @@ public void SearchForNotNullMemberChecks(BinaryExpression expression) // Example: o.Status != null && o.Status != "New" // Example: (o.Status != null && o.OldStatus != null) && (o.Status != o.OldStatus) // Example: (o.Status != null && o.OldStatus != null) && (o.Status == o.OldStatus) + // Example: o.Status != null && (o.OldStatus != null && o.Status == o.OldStatus) if ( - expression.NodeType != ExpressionType.AndAlso || + _equalityNotNullMembers.ContainsKey(expression) || + !IsAndOrAndAlso(expression) || ( - expression.Right.NodeType != ExpressionType.NotEqual && - expression.Right.NodeType != ExpressionType.Equal + !IsAndOrAndAlso(expression.Right) && + !IsEqualOrNotEqual(expression.Right) ) || ( - expression.Left.NodeType != ExpressionType.AndAlso && - expression.Left.NodeType != ExpressionType.NotEqual + !IsAndOrAndAlso(expression.Left) && + !IsEqualOrNotEqual(expression.Left) )) { return; } - // Skip if there are no member access expressions on the right side - var notEqualExpression = (BinaryExpression) expression.Right; - if (!IsMemberAccess(notEqualExpression.Left) && !IsMemberAccess(notEqualExpression.Right)) - { - return; - } - - var notNullMembers = new List(); - // We may have multiple conditions - // Example: o.Status != null && o.OldStatus != null - if (expression.Left.NodeType == ExpressionType.AndAlso) - { - FindAllNotNullMembers((BinaryExpression) expression.Left, notNullMembers); - } - else - { - FindNotNullMember((BinaryExpression) expression.Left, notNullMembers); - } - - if (notNullMembers.Count > 0) - { - _equalityNotNullMembers[notEqualExpression] = notNullMembers; - } + // Find all not null members and cache them for each binary expression that is found, + // in order to verify whether the member in a binary expression is nullable or not + FindAllNotNullMembers(expression, new List()); } public bool IsNullable(Expression expression, BinaryExpression equalityExpression) @@ -157,11 +139,8 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa return IsNullable(memberExpression.Expression, equalityExpression); } - // Check if there was a not null check prior the equality expression - if (( - equalityExpression.NodeType == ExpressionType.NotEqual || - equalityExpression.NodeType == ExpressionType.Equal - ) && + // Check if there was a not null check prior or after the equality expression + if (IsEqualOrNotEqual(equalityExpression) && _equalityNotNullMembers.TryGetValue(equalityExpression, out var notNullMembers) && notNullMembers.Any(o => AreEqual(o, memberExpression))) { @@ -208,51 +187,66 @@ private bool IsNullableExtension(Expression extensionExpression, BinaryExpressio } } - private static bool IsMemberAccess(Expression expression) + private static bool TryGetMemberAccess(Expression expression, out MemberExpression memberExpression) { - if (expression.NodeType == ExpressionType.MemberAccess) + memberExpression = expression as MemberExpression; + if (memberExpression != null) { return true; } // Nullable members can be wrapped in a convert expression - return expression is UnaryExpression unaryExpression && unaryExpression.Operand.NodeType == ExpressionType.MemberAccess; + if (expression is UnaryExpression unaryExpression) + { + memberExpression = unaryExpression.Operand as MemberExpression; + } + + return memberExpression != null; } - private static void FindAllNotNullMembers(BinaryExpression andAlsoExpression, List notNullMembers) + private void FindAllNotNullMembers(Expression expression, List notNullMembers) { - if (andAlsoExpression.Right.NodeType == ExpressionType.NotEqual) - { - FindNotNullMember((BinaryExpression) andAlsoExpression.Right, notNullMembers); - } - else if (andAlsoExpression.Right.NodeType == ExpressionType.AndAlso) - { - FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Right, notNullMembers); - } - else + if (!(expression is BinaryExpression binaryExpression)) { return; } - if (andAlsoExpression.Left.NodeType == ExpressionType.NotEqual) + // We may have multiple conditions + // Example: o.Status != null && o.OldStatus != null + // Example: o.Status != null && (o.OldStatus != null && o.Test > 0) + if (IsAndOrAndAlso(expression)) { - FindNotNullMember((BinaryExpression) andAlsoExpression.Left, notNullMembers); + FindAllNotNullMembers(binaryExpression, notNullMembers); } - else if (andAlsoExpression.Left.NodeType == ExpressionType.AndAlso) + else if (IsEqualOrNotEqual(expression)) { - FindAllNotNullMembers((BinaryExpression) andAlsoExpression.Left, notNullMembers); + FindNotNullMember(binaryExpression, notNullMembers); } } - private static void FindNotNullMember(BinaryExpression notEqualExpression, List notNullMembers) + private void FindAllNotNullMembers(BinaryExpression binaryExpression, List notNullMembers) { - if (notEqualExpression.Left.NodeType == ExpressionType.MemberAccess && VisitorUtil.IsNullConstant(notEqualExpression.Right)) + _equalityNotNullMembers.Add(binaryExpression, notNullMembers); + FindAllNotNullMembers(binaryExpression.Left, notNullMembers); + FindAllNotNullMembers(binaryExpression.Right, notNullMembers); + } + + private void FindNotNullMember(BinaryExpression binaryExpression, List notNullMembers) + { + _equalityNotNullMembers[binaryExpression] = notNullMembers; + if (binaryExpression.NodeType != ExpressionType.NotEqual) { - notNullMembers.Add((MemberExpression) notEqualExpression.Left); + return; } - else if (VisitorUtil.IsNullConstant(notEqualExpression.Left) && notEqualExpression.Right.NodeType == ExpressionType.MemberAccess) + + MemberExpression memberExpression; + if (VisitorUtil.IsNullConstant(binaryExpression.Right) && TryGetMemberAccess(binaryExpression.Left, out memberExpression)) { - notNullMembers.Add((MemberExpression) notEqualExpression.Right); + notNullMembers.Add(memberExpression); + } + else if (VisitorUtil.IsNullConstant(binaryExpression.Left) && TryGetMemberAccess(binaryExpression.Right, out memberExpression)) + { + notNullMembers.Add(memberExpression); } } @@ -286,5 +280,17 @@ private static bool AreEqual(MemberExpression memberExpression, MemberExpression return memberExpression.Expression == otherMemberExpression.Expression; } } + + private static bool IsAndOrAndAlso(Expression expression) + { + return expression.NodeType == ExpressionType.And || + expression.NodeType == ExpressionType.AndAlso; + } + + private static bool IsEqualOrNotEqual(Expression expression) + { + return expression.NodeType == ExpressionType.Equal || + expression.NodeType == ExpressionType.NotEqual; + } } }