From b15be9cfb54db9b4ba575dc9dec082fe865e7708 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 26 Apr 2020 22:44:53 +0200 Subject: [PATCH 01/11] Add Linq parameter type detection --- .../Northwind/Entities/User.cs | 6 + .../Northwind/Mappings/User.hbm.xml | 6 + src/NHibernate.Test/Async/Linq/EnumTests.cs | 37 ++ src/NHibernate.Test/Linq/ConstantTest.cs | 8 +- .../Linq/ConstantTypeLocatorTests.cs | 398 ++++++++++++++++++ src/NHibernate.Test/Linq/EnumTests.cs | 37 ++ src/NHibernate.Test/Linq/TryGetMappedTests.cs | 3 +- .../NHSpecificTest/GH1526/Fixture.cs | 2 +- .../Async/Linq/DefaultQueryProvider.cs | 3 +- src/NHibernate/Driver/OdbcDriver.cs | 12 +- src/NHibernate/Driver/SqlClientDriver.cs | 20 +- src/NHibernate/Driver/SqlServerCeDriver.cs | 6 + .../Ast/ANTLR/ASTQueryTranslatorFactory.cs | 16 +- src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs | 43 +- .../Hql/Ast/ANTLR/QueryTranslatorImpl.cs | 44 +- .../Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs | 32 +- .../Tree/BinaryArithmeticOperatorNode.cs | 52 +-- .../Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs | 11 +- .../Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs | 5 +- src/NHibernate/Impl/AbstractQueryImpl.cs | 82 +--- src/NHibernate/Linq/DefaultQueryProvider.cs | 36 +- src/NHibernate/Linq/NhLinqExpression.cs | 30 +- .../Linq/ReWriters/AddJoinsReWriter.cs | 4 +- .../Linq/Visitors/ConstantTypeLocator.cs | 293 +++++++++++++ .../Linq/Visitors/ExpressionKeyVisitor.cs | 121 ++++-- .../Visitors/ExpressionParameterVisitor.cs | 66 +-- .../Visitors/MemberExpressionJoinDetector.cs | 6 +- src/NHibernate/Linq/Visitors/VisitorUtil.cs | 6 + .../Linq/Visitors/WhereJoinDetector.cs | 6 +- src/NHibernate/Param/NamedListParameter.cs | 13 + src/NHibernate/Param/NamedParameter.cs | 4 +- src/NHibernate/Util/ParameterHelper.cs | 166 ++++++++ src/NHibernate/Util/ReflectionCache.cs | 10 + 33 files changed, 1340 insertions(+), 244 deletions(-) create mode 100644 src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs create mode 100644 src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs create mode 100644 src/NHibernate/Param/NamedListParameter.cs create mode 100644 src/NHibernate/Util/ParameterHelper.cs diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index c23e667be9b..14096dac912 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -48,10 +48,16 @@ public class User : IUser, IEntity public virtual FeatureSet Features { get; set; } + public virtual User NotMappedUser => this; + public virtual EnumStoredAsString Enum1 { get; set; } + public virtual EnumStoredAsString? NullableEnum1 { get; set; } + public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual EnumStoredAsInt32? NullableEnum2 { get; set; } + public virtual IUser CreatedBy { get; set; } public virtual IUser ModifiedBy { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index 2764cb70898..f249de9574e 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -24,8 +24,14 @@ + + + + + diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index 622a806ed30..6e9355d294c 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -61,5 +61,42 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async() Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public async Task ConditionalNavigationPropertyAsync() + { + EnumStoredAsString? type = null; + await (db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToListAsync()); + await (db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToListAsync()); + await (db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToListAsync()); + } + + [Test] + public async Task CanQueryComplexExpressionOnEnumStoredAsStringAsync() + { + var type = EnumStoredAsString.Unspecified; + var query = await ((from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToListAsync()); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 6b693ddbc4a..f75bf3a9070 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -217,12 +217,12 @@ public void ConstantInWhereDoesNotCauseManyKeys() select c); var preTransformParameters = new PreTransformationParameters(QueryMode.Select, Sfi); var preTransformResult = NhRelinqQueryParser.PreTransform(q1.Expression, preTransformParameters); - var expression = ExpressionParameterVisitor.Visit(preTransformResult, out var parameters1); - var k1 = ExpressionKeyVisitor.Visit(expression, parameters1); + var parameters1 = ExpressionParameterVisitor.Visit(preTransformResult); + var k1 = ExpressionKeyVisitor.Visit(preTransformResult.Expression, parameters1, Sfi); var preTransformResult2 = NhRelinqQueryParser.PreTransform(q2.Expression, preTransformParameters); - var expression2 = ExpressionParameterVisitor.Visit(preTransformResult2, out var parameters2); - var k2 = ExpressionKeyVisitor.Visit(expression2, parameters2); + var parameters2 = ExpressionParameterVisitor.Visit(preTransformResult2); + var k2 = ExpressionKeyVisitor.Visit(preTransformResult2.Expression, parameters2, Sfi); Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1"); Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2"); diff --git a/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs new file mode 100644 index 00000000000..4e21a078ebe --- /dev/null +++ b/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs @@ -0,0 +1,398 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Expressions; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Type; +using NUnit.Framework; +using Remotion.Linq.Clauses; + +namespace NHibernate.Test.Linq +{ + public class ConstantTypeLocatorTests : LinqTestCase + { + [Test] + public void AddIntegerTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5", o => o is Int32Type}, + }, + db.Users.Where(o => o.Id + 5 > 2.1), + db.Users.Where(o => 2.1 < 5 + o.Id) + ); + } + + [Test] + public void AddDecimalTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DecimalType}, + {"5.2", o => o is DecimalType}, + }, + db.Users.Where(o => o.Id + 5.2m > 2.1m), + db.Users.Where(o => 2.1m < 5.2m + o.Id) + ); + } + + [Test] + public void SubtractFloatTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5.2", o => o is SingleType}, + }, + db.Users.Where(o => o.Id - 5.2f > 2.1), + db.Users.Where(o => 2.1 < 5.2f - o.Id) + ); + } + + [Test] + public void GreaterThanTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is Int32Type} + }, + db.Users.Where(o => o.Id > 2.1), + db.Users.Where(o => 2.1 > o.Id) + ); + } + + [Test] + public void EqualStringEnumTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void EqualStringTest() + { + AssertResults( + new Dictionary> + { + {"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15} + }, + db.Orders.Where(o => o.ShippingAddress.City == "London"), + db.Orders.Where(o => "London" == o.ShippingAddress.City) + ); + } + + [Test] + public void DoubleEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large && o.Enum2 == EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High == o.Enum2 && EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void NotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1) + ); + } + + [Test] + public void DoubleNotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large || o.NullableEnum2 != EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High != o.NullableEnum2 || o.Enum1 != EnumStoredAsString.Large) + ); + } + + [Test] + public void CoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum1 ?? EnumStoredAsString.Large)) + ); + } + + [Test] + public void DoubleCoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + db.Users.Where(o => ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1)) + ); + } + + [Test] + public void ConditionalTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType}, + {"null", o => o is PersistentEnumType}, // HasValue + }, + db.Users.Where(o => (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1)) + ); + } + + [Test] + public void DoubleConditionalTest() + { + AssertResults( + new Dictionary> + { + {"0", o => o is PersistentEnumType}, + {"2", o => o is EnumStoredAsStringType}, + {"Small", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType}, + {"null", o => o is PersistentEnumType}, // HasValue + }, + db.Users.Where(o => (o.Enum2 != EnumStoredAsInt32.Unspecified + ? (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Enum2 != EnumStoredAsInt32.Unspecified + ? EnumStoredAsString.Small + : (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1))) + ); + } + + [Test] + public void CoalesceMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NotMappedUser ?? o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o ?? o.NotMappedUser).Enum1) + ); + } + + [Test] + public void ConditionalMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"\"test\"", o => o is AnsiStringType}, + }, + db.Users.Where(o => (o.Name == "test" ? o.NotMappedUser : o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Name == "test" ? o : o.NotMappedUser).Enum1) + ); + } + + + [Test] + public void AssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Component = new UserComponent {Property1 = "prop1"}} + ); + } + + [Test] + public void AssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User + { + Component = new UserComponent {OtherComponent = new UserComponent2 {OtherProperty1 = "other"}} + } + ); + } + + [Test] + public void AnonymousAssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AnonymousAssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {Property1 = "prop1"}} + ); + } + + [Test] + public void AnonymousAssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {OtherComponent = new {OtherProperty1 = "other"}}} + ); + } + + private void AssertResults( + Dictionary> expectedResults, + params IQueryable[] queries) + { + foreach (var query in queries) + { + AssertResult(expectedResults, query); + } + } + + private void AssertResult( + Dictionary> expectedResults, + IQueryable query) + { + AssertResult(expectedResults, QueryMode.Select, query.Expression, query.Expression.Type); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpressionFromAnonymous(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + Expression expression, + System.Type targetType) + { + var result = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(queryMode, Sfi)); + expression = result.Expression; + var queryModel = NhRelinqQueryParser.Parse(expression); + var types = ConstantTypeLocator.GetTypes(queryModel, targetType, Sfi); + Assert.That(types.Count, Is.EqualTo(expectedResults.Count), "Incorrect number of constants"); + foreach (var pair in types) + { + var origCulture = CultureInfo.CurrentCulture; + try + { + CultureInfo.CurrentCulture = CultureInfo.InvariantCulture; + var expressionText = pair.Key.ToString(); + Assert.That(expectedResults.ContainsKey(expressionText), Is.True, $"{expressionText} constant is not expected"); + Assert.That(expectedResults[expressionText](pair.Value), Is.True, $"Invalid type, actual type: {pair.Value?.Name ?? "null"}"); + } + finally + { + CultureInfo.CurrentCulture = origCulture; + } + } + } + } +} diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index 4050c7ddb97..aeea060b51e 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -48,5 +48,42 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public void ConditionalNavigationProperty() + { + EnumStoredAsString? type = null; + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToList(); + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToList(); + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToList(); + } + + [Test] + public void CanQueryComplexExpressionOnEnumStoredAsString() + { + var type = EnumStoredAsString.Unspecified; + var query = (from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToList(); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index 20610d32bad..880ce2e3e69 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -774,7 +774,8 @@ private void AssertResult( var expression = query.Expression; var preTransformResult = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(QueryMode.Select, Sfi)); - expression = ExpressionParameterVisitor.Visit(preTransformResult, out var constantToParameterMap); + expression = preTransformResult.Expression; + var constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); var queryModel = NhRelinqQueryParser.Parse(expression); var requiredHqlParameters = new List(); var visitorParameters = new VisitorParameters( diff --git a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs index 5318d771ec2..ee2b82c02f9 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs @@ -71,7 +71,7 @@ public void ShouldCreateDifferentKeys_TypeBinaryExpression() private static string GetCacheKey(Expression exp) { - return ExpressionKeyVisitor.Visit(exp, new Dictionary()); + return ExpressionKeyVisitor.Visit(exp, new Dictionary(), null); } } } diff --git a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs index 4d0344a27eb..2cc4b8863b4 100644 --- a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs @@ -21,6 +21,7 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { @@ -103,7 +104,7 @@ public Task ExecuteDmlAsync(QueryMode queryMode, Expression expression, var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdateAsync(cancellationToken); } diff --git a/src/NHibernate/Driver/OdbcDriver.cs b/src/NHibernate/Driver/OdbcDriver.cs index cf8df041cea..5ac80336849 100644 --- a/src/NHibernate/Driver/OdbcDriver.cs +++ b/src/NHibernate/Driver/OdbcDriver.cs @@ -78,10 +78,18 @@ private void SetVariableLengthParameterSize(DbParameter dbParam, SqlType sqlType { switch (dbParam.DbType) { - case DbType.AnsiString: + case DbType.StringFixedLength: case DbType.AnsiStringFixedLength: + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } + + break; case DbType.String: - case DbType.StringFixedLength: + case DbType.AnsiString: // NH-4083: do not limit to column length if above 2000. Setting size may trigger conversion from // nvarchar to ntext when size is superior or equal to 2000, causing some queries to fail: // https://stackoverflow.com/q/8569844/1178314 diff --git a/src/NHibernate/Driver/SqlClientDriver.cs b/src/NHibernate/Driver/SqlClientDriver.cs index 682d755472a..ec876494370 100644 --- a/src/NHibernate/Driver/SqlClientDriver.cs +++ b/src/NHibernate/Driver/SqlClientDriver.cs @@ -161,7 +161,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq { case DbType.AnsiString: case DbType.AnsiStringFixedLength: - dbParam.Size = IsAnsiText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForAnsiClob : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; + dbParam.Size = IsAnsiText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForAnsiClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; break; case DbType.Binary: dbParam.Size = IsBlob(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForBlob : MsSql2000Dialect.MaxSizeForLengthLimitedBinary; @@ -174,7 +176,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq break; case DbType.String: case DbType.StringFixedLength: - dbParam.Size = IsText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForClob : MsSql2000Dialect.MaxSizeForLengthLimitedString; + dbParam.Size = IsText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedString; break; case DbType.DateTime2: dbParam.Size = MsSql2000Dialect.MaxDateTime2; @@ -283,6 +287,18 @@ protected static bool IsBlob(DbParameter dbParam, SqlType sqlType) return (sqlType is BinaryBlobSqlType) || ((DbType.Binary == dbParam.DbType) && sqlType.LengthDefined && (sqlType.Length > MsSql2000Dialect.MaxSizeForLengthLimitedBinary)); } + /// + /// Interprets if a parameter is a character (for the purposes of setting its default size) + /// + /// The parameter + /// The of the parameter + /// True, if the parameter should be interpreted as a character, otherwise False + protected static bool IsChar(DbParameter dbParam, SqlType sqlType) + { + return (DbType.StringFixedLength == dbParam.DbType || DbType.StringFixedLength == dbParam.DbType) && + sqlType.LengthDefined && sqlType.Length == 1; + } + public override IResultSetsCommand GetResultSetsCommand(ISessionImplementor session) { return new BasicResultSetsCommand(session); diff --git a/src/NHibernate/Driver/SqlServerCeDriver.cs b/src/NHibernate/Driver/SqlServerCeDriver.cs index eb4f03316ea..0b6a4ad93bc 100644 --- a/src/NHibernate/Driver/SqlServerCeDriver.cs +++ b/src/NHibernate/Driver/SqlServerCeDriver.cs @@ -75,6 +75,12 @@ public override IResultSetsCommand GetResultSetsCommand(Engine.ISessionImplement protected override void InitializeParameter(DbParameter dbParam, string name, SqlType sqlType) { base.InitializeParameter(dbParam, name, AdjustSqlType(sqlType)); + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.LengthDefined && sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } AdjustDbParamTypeForLargeObjects(dbParam, sqlType); } diff --git a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs index 7b9e937bd18..e7b95eed2cb 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using NHibernate.Engine; using NHibernate.Hql.Ast.ANTLR.Tree; +using NHibernate.Linq; using NHibernate.Util; namespace NHibernate.Hql.Ast.ANTLR @@ -16,15 +17,24 @@ public class ASTQueryTranslatorFactory : IQueryTranslatorFactory { public IQueryTranslator[] CreateQueryTranslators(IQueryExpression queryExpression, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) { - return CreateQueryTranslators(queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); + return CreateQueryTranslators(queryExpression, queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); } - static IQueryTranslator[] CreateQueryTranslators(IASTNode ast, string queryIdentifier, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) + static IQueryTranslator[] CreateQueryTranslators( + IQueryExpression queryExpression, + IASTNode ast, + string queryIdentifier, + string collectionRole, + bool shallow, + IDictionary filters, + ISessionFactoryImplementor factory) { var polymorphicParsers = AstPolymorphicProcessor.Process(ast, factory); var translators = polymorphicParsers - .ToArray(hql => new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); + .ToArray(hql => queryExpression is NhLinqExpression linqExpression + ? new QueryTranslatorImpl(queryIdentifier, hql, filters, factory, linqExpression.NamedParameters) + : new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); foreach (var translator in translators) { diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs index 37e5eaffc6e..d6f3a7f0861 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs @@ -36,7 +36,7 @@ public partial class HqlSqlWalker private string _statementTypeName; private int _positionalParameterCount; private int _parameterCount; - private readonly NullableDictionary _namedParameters = new NullableDictionary(); + private readonly NullableDictionary _namedParameterLocations = new NullableDictionary(); private readonly List _parameters = new List(); private FromClause _currentFromClause; private SelectClause _selectClause; @@ -54,6 +54,7 @@ public partial class HqlSqlWalker private readonly LiteralProcessor _literalProcessor; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private JoinType _impliedJoinType; @@ -64,17 +65,30 @@ public partial class HqlSqlWalker private int numberOfParametersInSetClause; private Stack clauseStack=new Stack(); - public HqlSqlWalker(QueryTranslatorImpl qti, - ISessionFactoryImplementor sfi, - ITreeNodeStream input, - IDictionary tokenReplacements, - string collectionRole) + public HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + string collectionRole) + : this(qti, sfi, input, tokenReplacements, null, collectionRole) + { + } + + internal HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) : this(input) { _sessionFactoryHelper = new SessionFactoryHelperExtensions(sfi); _qti = qti; _literalProcessor = new LiteralProcessor(this); _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionFilterRole = collectionRole; } @@ -122,7 +136,7 @@ public ISet QuerySpaces public IDictionary NamedParameters { - get { return _namedParameters; } + get { return _namedParameterLocations; } } internal SessionFactoryHelperExtensions SessionFactoryHelper @@ -1033,13 +1047,20 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode) ); parameter.HqlParameterSpecification = paramSpec; + if (_namedParameters != null && _namedParameters.TryGetValue(name, out var namedParameter)) + { + // Add the parameter type information so that we are able to calculate functions return types + // when the parameter is used as an argument. + parameter.ExpectedType = namedParameter.Type; + } + _parameters.Add(paramSpec); return parameter; } IASTNode GeneratePositionalParameter(IASTNode inputNode) { - if (_namedParameters.Count > 0) + if (_namedParameterLocations.Count > 0) { // NH TODO: remove this limitation throw new SemanticException("cannot define positional parameter after any named parameters have been defined"); @@ -1171,15 +1192,15 @@ public void AddQuerySpaces(string[] spaces) private void TrackNamedParameterPositions(string name) { int loc = _parameterCount++; - object o = _namedParameters[name]; + object o = _namedParameterLocations[name]; if ( o == null ) { - _namedParameters.Add(name, loc); + _namedParameterLocations.Add(name, loc); } else if (o is int) { List list = new List(4) {(int) o, loc}; - _namedParameters[name] = list; + _namedParameterLocations[name] = list; } else { diff --git a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs index 2e27559d1dd..bcf3dc14e11 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs @@ -29,7 +29,8 @@ public partial class QueryTranslatorImpl : IFilterTranslator private readonly string _queryIdentifier; private readonly IASTNode _stageOneAst; private readonly ISessionFactoryImplementor _factory; - + private readonly IDictionary _namedParameters; + private bool _shallowQuery; private bool _compiled; private IDictionary _enabledFilters; @@ -47,10 +48,28 @@ public partial class QueryTranslatorImpl : IFilterTranslator /// Currently enabled filters /// The session factory constructing this translator instance. public QueryTranslatorImpl( - string queryIdentifier, - IASTNode parsedQuery, - IDictionary enabledFilters, - ISessionFactoryImplementor factory) + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory) + : this(queryIdentifier, parsedQuery, enabledFilters, factory, null) + { + } + + /// + /// Creates a new AST-based query translator. + /// + /// The query-identifier (used in stats collection) + /// The hql query to translate + /// Currently enabled filters + /// The session factory constructing this translator instance. + /// The named parameters information. + internal QueryTranslatorImpl( + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory, + IDictionary namedParameters) { _queryIdentifier = queryIdentifier; _stageOneAst = parsedQuery; @@ -58,6 +77,7 @@ public QueryTranslatorImpl( _shallowQuery = false; _enabledFilters = enabledFilters; _factory = factory; + _namedParameters = namedParameters; } /// @@ -434,7 +454,7 @@ private static IStatementExecutor BuildAppropriateStatementExecutor(IStatement s private HqlSqlTranslator Analyze(string collectionRole) { - var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, collectionRole); + var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, _namedParameters, collectionRole); translator.Translate(); @@ -548,15 +568,23 @@ internal class HqlSqlTranslator private readonly QueryTranslatorImpl _qti; private readonly ISessionFactoryImplementor _sfi; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private readonly string _collectionRole; private IStatement _resultAst; - public HqlSqlTranslator(IASTNode ast, QueryTranslatorImpl qti, ISessionFactoryImplementor sfi, IDictionary tokenReplacements, string collectionRole) + public HqlSqlTranslator( + IASTNode ast, + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) { _inputAst = ast; _qti = qti; _sfi = sfi; _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionRole = collectionRole; } @@ -576,7 +604,7 @@ public IStatement Translate() var nodes = new BufferedTreeNodeStream(_inputAst); - var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _collectionRole); + var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _namedParameters, _collectionRole); hqlSqlWalker.TreeAdaptor = new HqlSqlWalkerTreeAdaptor(hqlSqlWalker); try diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs index fd91e09fd3a..8b909121987 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs @@ -63,25 +63,27 @@ private IASTNode GetHighOperand() private static void Check(IASTNode check, IASTNode first, IASTNode second) { - var expectedTypeAwareNode = check as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (!(check is IExpectedTypeAwareNode expectedTypeAwareNode) || + expectedTypeAwareNode.ExpectedType != null) { - IType expectedType = null; - var firstNode = first as SqlNode; - if (firstNode != null) - { - expectedType = firstNode.DataType; - } - if (expectedType == null) + return; + } + + IType expectedType = null; + if (first is SqlNode firstNode) + { + expectedType = firstNode.DataType; + } + + if (expectedType == null) + { + if (second is SqlNode secondNode) { - var secondNode = second as SqlNode; - if (secondNode != null) - { - expectedType = secondNode.DataType; - } + expectedType = secondNode.DataType; } - expectedTypeAwareNode.ExpectedType = expectedType; } + + expectedTypeAwareNode.ExpectedType = expectedType; } } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs index 9b706facb49..60ca24ca379 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs @@ -32,32 +32,34 @@ public void Initialize() IType lhType = (lhs is SqlNode) ? ((SqlNode)lhs).DataType : null; IType rhType = (rhs is SqlNode) ? ((SqlNode)rhs).DataType : null; - if (lhs is IExpectedTypeAwareNode && rhType != null) + TrySetExpectedType(lhs, rhType, true); + TrySetExpectedType(rhs, lhType, false); + } + + private void TrySetExpectedType(IASTNode operand, IType otherOperandType, bool leftHandOperand) + { + if (!(operand is IExpectedTypeAwareNode typeAwareNode) || + otherOperandType == null || + typeAwareNode.ExpectedType != null) { - IType expectedType; + return; + } + + IType expectedType = null; - // we have something like : "? [op] rhs" - if (IsDateTimeType(rhType)) + // we have something like : "lhs [op] ?" or "? [op] rhs" + if (IsDateTimeType(otherOperandType)) + { + if (leftHandOperand) { // more specifically : "? [op] datetime" // 1) if the operator is MINUS, the param needs to be of // some datetime type // 2) if the operator is PLUS, the param needs to be of // some numeric type - expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : rhType; + expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : otherOperandType; } - else - { - expectedType = rhType; - } - ((IExpectedTypeAwareNode)lhs).ExpectedType = expectedType; - } - else if (rhs is ParameterNode && lhType != null) - { - IType expectedType = null; - - // we have something like : "lhs [op] ?" - if (IsDateTimeType(lhType)) + else if (Type == HqlSqlWalker.PLUS) { // more specifically : "datetime [op] ?" // 1) if the operator is MINUS, we really cannot determine @@ -65,17 +67,15 @@ public void Initialize() // numeric would be valid // 2) if the operator is PLUS, the param needs to be of // some numeric type - if (Type == HqlSqlWalker.PLUS) - { - expectedType = NHibernateUtil.Double; - } - } - else - { - expectedType = lhType; + expectedType = NHibernateUtil.Double; } - ((IExpectedTypeAwareNode)rhs).ExpectedType = expectedType; } + else + { + expectedType = otherOperandType; + } + + typeAwareNode.ExpectedType = expectedType; } public override IType DataType diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs index bf0560dfc76..cae4b920ec8 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs @@ -65,15 +65,14 @@ public virtual void Initialize() rhsType = lhsType; } - var lshExpectedTypeAwareNode = lhs as IExpectedTypeAwareNode; - if (lshExpectedTypeAwareNode != null) + if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null) { - lshExpectedTypeAwareNode.ExpectedType = rhsType; + lshTypeAwareNode.ExpectedType = rhsType; } - var rshExpectedTypeAwareNode = rhs as IExpectedTypeAwareNode; - if (rshExpectedTypeAwareNode != null) + + if (rhs is IExpectedTypeAwareNode rshTypeAwareNode && rshTypeAwareNode.ExpectedType == null) { - rshExpectedTypeAwareNode.ExpectedType = lhsType; + rshTypeAwareNode.ExpectedType = lhsType; } MutateRowValueConstructorSyntaxesIfNecessary( lhsType, rhsType ); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs index 0ad4e404bda..f0b9856f76a 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs @@ -47,11 +47,12 @@ public override void Initialize() IASTNode inListChild = inList.GetChild(0); while (inListChild != null) { - var expectedTypeAwareNode = inListChild as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (inListChild is IExpectedTypeAwareNode expectedTypeAwareNode && + expectedTypeAwareNode.ExpectedType == null) { expectedTypeAwareNode.ExpectedType = lhsType; } + inListChild = inListChild.NextSibling; } } diff --git a/src/NHibernate/Impl/AbstractQueryImpl.cs b/src/NHibernate/Impl/AbstractQueryImpl.cs index 9ff4c712b0d..ba46b665466 100644 --- a/src/NHibernate/Impl/AbstractQueryImpl.cs +++ b/src/NHibernate/Impl/AbstractQueryImpl.cs @@ -142,7 +142,8 @@ protected internal virtual IType DetermineType(int paramPosition, object paramVa protected internal virtual IType DetermineType(int paramPosition, object paramValue) { - IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? GuessType(paramValue); + IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } @@ -154,67 +155,15 @@ protected internal virtual IType DetermineType(string paramName, object paramVal protected internal virtual IType DetermineType(string paramName, object paramValue) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(paramValue); + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } protected internal virtual IType DetermineType(string paramName, System.Type clazz) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(clazz); - return type; - } - - /// - /// Guesses the from the param's value. - /// - /// The object to guess the of. - /// An for the object. - /// - /// Thrown when the param is null because the - /// can't be guess from a null value. - /// - private IType GuessType(object param) - { - if (param == null) - { - throw new ArgumentNullException("param", "The IType can not be guessed for a null value."); - } - - System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); - return GuessType(clazz); - } - - /// - /// Guesses the from the . - /// - /// The to guess the of. - /// An for the . - /// - /// Thrown when the clazz is null because the - /// can't be guess from a null type. - /// - private IType GuessType(System.Type clazz) - { - if (clazz == null) - { - throw new ArgumentNullException("clazz", "The IType can not be guessed for a null value."); - } - - var type = TypeFactory.HeuristicType(clazz); - if (type == null || type is SerializableType) - { - if (session.Factory.TryGetEntityPersister(clazz.FullName) != null) - { - return NHibernateUtil.Entity(clazz); - } - - if (type == null) - { - throw new HibernateException( - "Could not determine a type for class: " + clazz.AssemblyQualifiedName); - } - } - + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(clazz, session.Factory); return type; } @@ -310,7 +259,11 @@ public IQuery SetParameter(int position, T val) { CheckPositionalParameter(position); - return SetParameter(position, val, parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? GuessType(typeof(T))); + return SetParameter( + position, + val, + parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } private void CheckPositionalParameter(int position) @@ -327,7 +280,11 @@ private void CheckPositionalParameter(int position) public IQuery SetParameter(string name, T val) { - return SetParameter(name, val, parameterMetadata.GetNamedParameterExpectedType(name) ?? GuessType(typeof (T))); + return SetParameter( + name, + val, + parameterMetadata.GetNamedParameterExpectedType(name) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } public IQuery SetParameter(string name, object val) @@ -792,7 +749,12 @@ public IQuery SetParameterList(string name, IEnumerable vals) } object firstValue = vals.Cast().FirstOrDefault(); - SetParameterList(name, vals, firstValue == null ? GuessType(vals.GetCollectionElementType()) : DetermineType(name, firstValue)); + SetParameterList( + name, + vals, + firstValue == null + ? ParameterHelper.GuessType(vals.GetCollectionElementType(), session.Factory) + : DetermineType(name, firstValue)); return this; } diff --git a/src/NHibernate/Linq/DefaultQueryProvider.cs b/src/NHibernate/Linq/DefaultQueryProvider.cs index c8de5a37a5e..912b640a951 100644 --- a/src/NHibernate/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Linq/DefaultQueryProvider.cs @@ -11,6 +11,7 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { @@ -211,7 +212,7 @@ protected virtual NhLinqExpression PrepareQuery(Expression expression, out IQuer query = Session.CreateFilter(Collection, nhLinqExpression); } - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); SetResultTransformerAndAdditionalCriteria(query, nhLinqExpression, nhLinqExpression.ParameterValuesByName); @@ -252,38 +253,19 @@ protected virtual object ExecuteQuery(NhLinqExpression nhLinqExpression, IQuery #pragma warning restore 618 } - private static void SetParameters(IQuery query, IDictionary> parameters) + private static void SetParameters(IQuery query, IDictionary parameters) { foreach (var parameterName in query.NamedParameters) { - var param = parameters[parameterName]; - - if (param.Item1 == null) + // The parameter type will be taken from the parameter metadata + var parameter = parameters[parameterName]; + if (parameter.IsCollection) { - if (typeof(IEnumerable).IsAssignableFrom(param.Item2.ReturnedClass) && - param.Item2.ReturnedClass != typeof(string)) - { - query.SetParameterList(parameterName, null, param.Item2); - } - else - { - query.SetParameter(parameterName, null, param.Item2); - } + query.SetParameterList(parameter.Name, (IEnumerable) parameter.Value); } else { - if (param.Item1 is IEnumerable && !(param.Item1 is string)) - { - query.SetParameterList(parameterName, (IEnumerable)param.Item1); - } - else if (param.Item2 != null) - { - query.SetParameter(parameterName, param.Item1, param.Item2); - } - else - { - query.SetParameter(parameterName, param.Item1); - } + query.SetParameter(parameter.Name, parameter.Value); } } } @@ -310,7 +292,7 @@ public int ExecuteDml(QueryMode queryMode, Expression expression) var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdate(); } diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index ad39397fd71..50b41b4f927 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -8,6 +8,7 @@ using NHibernate.Linq.Visitors; using NHibernate.Param; using NHibernate.Type; +using Remotion.Linq; namespace NHibernate.Linq { @@ -34,6 +35,8 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression protected virtual QueryMode QueryMode { get; } + internal IDictionary NamedParameters { get; } + private readonly Expression _expression; private readonly IDictionary _constantToParameterMap; @@ -56,12 +59,12 @@ internal NhLinqExpression(QueryMode queryMode, Expression expression, ISessionFa // referenced from the main query. LinqLogging.LogExpression("Expression (partially evaluated)", _expression); - _expression = ExpressionParameterVisitor.Visit(preTransformResult, out _constantToParameterMap); + _constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); ParameterValuesByName = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name, - p => System.Tuple.Create(p.Value, p.Type)); - - Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); + p => System.Tuple.Create(p.Value, p.Type)); + NamedParameters = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name); + Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap, sessionFactory); Type = _expression.Type; @@ -88,6 +91,7 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter var requiredHqlParameters = new List(); var queryModel = NhRelinqQueryParser.Parse(_expression); queryModel.TransformExpressions(TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers); + SetParameterTypes(sessionFactory, queryModel); var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, new QuerySourceNamer(), TargetType, QueryMode); @@ -118,6 +122,24 @@ internal void CopyExpressionTranslation(NhLinqExpression other) Type = other.Type; } + private void SetParameterTypes( + ISessionFactoryImplementor sessionFactory, + QueryModel queryModel) + { + if (_constantToParameterMap.Count == 0) + { + return; + } + + foreach (var pair in ConstantTypeLocator.GetTypes(queryModel, TargetType, sessionFactory, true)) + { + if (_constantToParameterMap.TryGetValue(pair.Key, out var parameter)) + { + parameter.Type = pair.Value; + } + } + } + private static IASTNode DuplicateTree(IASTNode ast) { var thisNode = ast.DupNode(); diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 99a8e009571..9826ffdbf06 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -29,8 +29,8 @@ private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel q { _sessionFactory = sessionFactory; var joiner = new Joiner(queryModel, AddJoin); - _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); - _whereJoinDetector = new WhereJoinDetector(this, joiner); + _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner, _sessionFactory); + _whereJoinDetector = new WhereJoinDetector(this, joiner, _sessionFactory); } public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) diff --git a/src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs b/src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs new file mode 100644 index 00000000000..534f6623e78 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs @@ -0,0 +1,293 @@ +using System.Collections.Generic; +using System.Linq.Expressions; +using NHibernate.Engine; +using NHibernate.Type; +using NHibernate.Util; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + /// + /// Locates actual type based on its usage. + /// + public static class ConstantTypeLocator + { + /// + /// List of for which the should be related to the other side + /// of a (e.g. o.MyEnum == MyEnum.Option -> MyEnum.Option should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet ValidBinaryExpressionTypes = new HashSet + { + ExpressionType.Equal, + ExpressionType.NotEqual, + ExpressionType.GreaterThanOrEqual, + ExpressionType.GreaterThan, + ExpressionType.LessThan, + ExpressionType.LessThanOrEqual, + ExpressionType.Coalesce, + ExpressionType.Assign + }; + + /// + /// List of for which the should be copied across + /// as related (e.g. (o.MyEnum ?? MyEnum.Option) == MyEnum.Option2 -> MyEnum.Option2 should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet NonVoidOperators = new HashSet + { + ExpressionType.Coalesce, + ExpressionType.Conditional + }; + + /// + /// Calculate constant expressions types inside the query model. + /// + /// The query model. + /// The target entity type. + /// The session factory. + /// + public static Dictionary GetTypes( + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory) + { + return GetTypes(queryModel, targetType, sessionFactory, false); + } + + internal static Dictionary GetTypes( + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory, + bool removeMappedAsCalls) + { + var types = new Dictionary(); + var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, sessionFactory); + queryModel.TransformExpressions(visitor.Visit); + + foreach (var pair in visitor.ConstantExpressions) + { + var type = pair.Value; + var constantExpression = pair.Key; + if (type != null) + { + // MappedAs was used + types.Add(constantExpression, type); + continue; + } + + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. + if (visitor.RelatedExpressions.TryGetValue(constantExpression, out var memberExpressions)) + { + foreach (var memberExpression in memberExpressions) + { + if (ExpressionsHelper.TryGetMappedType( + sessionFactory, + memberExpression, + out type, + out _, + out _, + out _)) + { + break; + } + } + } + + // No related MemberExpressions was found, guess the type by value or its type when null. + if (type == null) + { + type = constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, out _) + : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, out _); + } + + types.Add(constantExpression, type); + } + + return types; + } + + private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor + { + private readonly bool _removeMappedAsCalls; + private readonly System.Type _targetType; + private readonly ISessionFactoryImplementor _sessionFactory; + public readonly Dictionary ConstantExpressions = + new Dictionary(); + public readonly Dictionary> RelatedExpressions = + new Dictionary>(); + + public ConstantTypeLocatorVisitor( + bool removeMappedAsCalls, + System.Type targetType, + ISessionFactoryImplementor sessionFactory) + { + _removeMappedAsCalls = removeMappedAsCalls; + _targetType = targetType; + _sessionFactory = sessionFactory; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + node = (BinaryExpression) base.VisitBinary(node); + if (!ValidBinaryExpressionTypes.Contains(node.NodeType)) + { + return node; + } + + var left = Unwrap(node.Left); + var right = Unwrap(node.Right); + if (node.NodeType == ExpressionType.Assign) + { + VisitAssign(left, right); + } + else + { + AddRelatedMemberExpression(node, left, right); + AddRelatedMemberExpression(node, right, left); + } + + return node; + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + node = (ConditionalExpression) base.VisitConditional(node); + var ifTrue = Unwrap(node.IfTrue); + var ifFalse = Unwrap(node.IfFalse); + AddRelatedMemberExpression(node, ifTrue, ifFalse); + AddRelatedMemberExpression(node, ifFalse, ifTrue); + + return node; + } + + protected override MemberAssignment VisitMemberAssignment(MemberAssignment node) + { + node = base.VisitMemberAssignment(node); + return node; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (VisitorUtil.IsMappedAs(node.Method)) + { + var rawParameter = Visit(node.Arguments[0]); + var parameter = rawParameter as ConstantExpression; + var type = node.Arguments[1] as ConstantExpression; + if (parameter == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} must be called on an expression which can be evaluated as " + + $"{nameof(ConstantExpression)}. It was call on {rawParameter?.GetType().Name ?? "null"} instead."); + if (type == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} type must be supplied as {nameof(ConstantExpression)}. " + + $"It was {node.Arguments[1]?.GetType().Name ?? "null"} instead."); + + ConstantExpressions[parameter] = (IType) type.Value; + + return _removeMappedAsCalls + ? rawParameter + : node; + } + + return base.VisitMethodCall(node); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node)) + { + return node; + } + + RelatedExpressions.Add(node, new HashSet()); + ConstantExpressions.Add(node, null); + return node; + } + + public override Expression Visit(Expression node) + { + if (node is SubQueryExpression subQueryExpression) + { + subQueryExpression.QueryModel.TransformExpressions(Visit); + } + + return base.Visit(node); + } + + private void VisitAssign(Expression leftNode, Expression rightNode) + { + // Insert and Update statements have assign expressions, where the left side is a parameter and its name + // represents the property path to be assigned + if (!(leftNode is ParameterExpression parameterExpression) || + !(rightNode is ConstantExpression constantExpression)) + { + return; + } + + var entityName = _sessionFactory.TryGetGuessEntityName(_targetType); + if (entityName == null) + { + return; + } + + var persister = _sessionFactory.GetEntityPersister(entityName); + ConstantExpressions[constantExpression] = persister.EntityMetamodel.GetPropertyType(parameterExpression.Name); + } + + private void AddRelatedMemberExpression(Expression node, Expression left, Expression right) + { + HashSet set; + if (left is MemberExpression leftMemberExpression) + { + AddMemberExpression(right, leftMemberExpression); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddMemberExpression(node, leftMemberExpression); + } + } + + // Copy all found MemberExpressions to the other side + // (e.g. (o.Prop ?? constant1) == constant2 -> copy o.Prop to constant2) + if (RelatedExpressions.TryGetValue(left, out set)) + { + foreach (var nestedMemberExpression in set) + { + AddMemberExpression(right, nestedMemberExpression); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddMemberExpression(node, nestedMemberExpression); + } + } + } + } + + private void AddMemberExpression(Expression expression, MemberExpression memberExpression) + { + if (!RelatedExpressions.TryGetValue(expression, out var set)) + { + set = new HashSet(); + RelatedExpressions.Add(expression, set); + } + + set.Add(memberExpression); + } + + private static Expression Unwrap(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + return unaryExpression.Operand; + } + + return expression; + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index ef4981d2aec..28261d2088c 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -7,7 +7,9 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Text; +using NHibernate.Engine; using NHibernate.Param; +using NHibernate.Type; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -22,22 +24,46 @@ namespace NHibernate.Linq.Visitors public class ExpressionKeyVisitor : RelinqExpressionVisitor { private readonly IDictionary _constantToParameterMap; + private readonly ISessionFactoryImplementor _sessionFactory; readonly StringBuilder _string = new StringBuilder(); - private ExpressionKeyVisitor(IDictionary constantToParameterMap) + private ExpressionKeyVisitor( + IDictionary constantToParameterMap, + ISessionFactoryImplementor sessionFactory) { _constantToParameterMap = constantToParameterMap; + _sessionFactory = sessionFactory; } + // Since v5.3 + [Obsolete("Use the overload with ISessionFactoryImplementor parameter")] public static string Visit(Expression expression, IDictionary parameters) { - var visitor = new ExpressionKeyVisitor(parameters); + var visitor = new ExpressionKeyVisitor(parameters, null); visitor.Visit(expression); return visitor.ToString(); } + /// + /// Generates the key for the expression. + /// + /// The expression. + /// The session factory. + /// Parameters found in . + /// The key for the expression. + public static string Visit( + Expression rootExpression, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + var visitor = new ExpressionKeyVisitor(parameters, sessionFactory); + visitor.Visit(rootExpression); + + return visitor.ToString(); + } + public override string ToString() { return _string.ToString(); @@ -86,49 +112,70 @@ protected override Expression VisitConstant(ConstantExpression expression) throw new InvalidOperationException("Cannot visit a constant without a constant to parameter map."); if (_constantToParameterMap.TryGetValue(expression, out param)) { - // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. - if (param.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = param.Value as IEnumerable; - if (value != null && !(value is string) && !value.Cast().Any()) - { - _string.Append("EmptyList"); - } - else - { - _string.Append(param.Name); - } - } + VisitParameter(param); } else { - if (expression.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = expression.Value as IEnumerable; - if (value != null && !(value is string) && !(value is IQueryable)) - { - _string.Append("{"); - _string.Append(String.Join(",", value.Cast())); - _string.Append("}"); - } - else - { - _string.Append(expression.Value); - } - } + VisitConstantValue(expression.Value); } return base.VisitConstant(expression); } + private void VisitConstantValue(object value) + { + if (value == null) + { + _string.Append("NULL"); + return; + } + + if (value is IEnumerable enumerable && !(value is IQueryable)) + { + _string.Append("{"); + _string.Append(string.Join(",", enumerable.Cast())); + _string.Append("}"); + return; + } + + // When MappedAs is used we have to put all sql types information in the key in order to + // distinct when different precisions/sizes are used. + if (_sessionFactory != null && value is IType type) + { + _string.Append(type.Name); + _string.Append('['); + _string.Append(string.Join(",", type.SqlTypes(_sessionFactory).Select(o => o.ToString()))); + _string.Append(']'); + return; + } + + _string.Append(value); + } + + private void VisitParameter(NamedParameter param) + { + // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. + if (param.Value == null) + { + _string.Append("NULL"); + return; + } + + if (param.IsCollection && !((IEnumerable) param.Value).Cast().Any()) + { + _string.Append("EmptyList"); + } + else + { + _string.Append(param.Name); + } + + // Add the type in order to avoid invalid parameter conversions (string -> char) + _string.Append("<"); + _string.Append(param.Value.GetType()); + _string.Append(">"); + } + private T AppendCommas(T expression) where T : Expression { Visit(expression); diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 45134248a51..51607c40c1a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; @@ -21,20 +22,13 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor private readonly IDictionary _queryVariables; private readonly ISessionFactoryImplementor _sessionFactory; - private static readonly MethodInfo QueryableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), 0); - private static readonly MethodInfo QueryableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), 0); - private static readonly MethodInfo EnumerableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), 0); - private static readonly MethodInfo EnumerableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), 0); - - private readonly ICollection _pagingMethods = new HashSet - { - QueryableSkipDefinition, QueryableTakeDefinition, - EnumerableSkipDefinition, EnumerableTakeDefinition - }; + private static readonly ISet PagingMethods = new HashSet + { + ReflectionCache.EnumerableMethods.SkipDefinition, + ReflectionCache.EnumerableMethods.TakeDefinition, + ReflectionCache.QueryableMethods.SkipDefinition, + ReflectionCache.QueryableMethods.TakeDefinition + }; // Since v5.3 [Obsolete("Please use overload with preTransformationResult parameter instead.")] @@ -59,22 +53,19 @@ public static IDictionary Visit(Expression e return visitor._parameters; } - public static Expression Visit( - PreTransformationResult preTransformationResult, - out IDictionary parameters) + public static IDictionary Visit(PreTransformationResult preTransformationResult) { var visitor = new ExpressionParameterVisitor(preTransformationResult); - var expression = visitor.Visit(preTransformationResult.Expression); - parameters = visitor._parameters; - - return expression; + visitor.Visit(preTransformationResult.Expression); + return visitor._parameters; } protected override Expression VisitMethodCall(MethodCallExpression expression) { - if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) + if (VisitorUtil.IsMappedAs(expression.Method)) { var rawParameter = Visit(expression.Arguments[0]); + // TODO 6.0: Remove below code and return expression as this logic is now inside ConstantTypeLocator var parameter = rawParameter as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) @@ -95,10 +86,10 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) ? expression.Method.GetGenericMethodDefinition() : expression.Method; - if (_pagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) + if (PagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) { - //TODO: find a way to make this code cleaner var query = Visit(expression.Arguments[0]); + //TODO 6.0: Remove the below code and return expression var arg = expression.Arguments[1]; if (query == expression.Arguments[0]) @@ -125,11 +116,14 @@ protected override Expression VisitConstant(ConstantExpression expression) // We have a bit more information about the null parameter value. // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) + // In v5.3 types are calculated by ConstantTypeLocator, this logic is only for back compatibility. + // TODO 6.0: Remove if (expression.Value == null) type = NHibernateUtil.GuessType(expression.Type); // Constant characters should be sent as strings - if (expression.Type == typeof(char)) + // TODO 6.0: Remove + if (_queryVariables == null && expression.Type == typeof(char)) { value = value.ToString(); } @@ -144,13 +138,13 @@ protected override Expression VisitConstant(ConstantExpression expression) _queryVariables.TryGetValue(expression, out var variable) && !_variableParameters.TryGetValue(variable, out parameter)) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); _variableParameters.Add(variable, parameter); } if (parameter == null) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); } _parameters.Add(expression, parameter); @@ -161,9 +155,27 @@ protected override Expression VisitConstant(ConstantExpression expression) return base.VisitConstant(expression); } + private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type) + { + var parameterName = "p" + (_parameters.Count + 1); + return IsCollectionType(expression) + ? new NamedListParameter(parameterName, value, type) + : new NamedParameter(parameterName, value, type); + } + private static bool IsNullObject(ConstantExpression expression) { return expression.Type == typeof(Object) && expression.Value == null; } + + private static bool IsCollectionType(ConstantExpression expression) + { + if (expression.Value != null) + { + return expression.Value is IEnumerable && !(expression.Value is string); + } + + return expression.Type.IsCollectionType(); + } } } diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 580ba3cf00c..019769fccb1 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -19,16 +19,18 @@ internal class MemberExpressionJoinDetector : RelinqExpressionVisitor { private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private bool _requiresJoinForNonIdentifier; private bool _preventJoinsInConditionalTest; private bool _hasIdentifier; private int _memberExpressionDepth; - public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } protected override Expression VisitMember(MemberExpression expression) @@ -55,7 +57,7 @@ protected override Expression VisitMember(MemberExpression expression) ((_requiresJoinForNonIdentifier && !_hasIdentifier) || _memberExpressionDepth > 0) && _joiner.CanAddJoin(expression)) { - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); return _joiner.AddJoin(result, key); } diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 22ac89dd0aa..885d2c66a4e 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -131,5 +131,11 @@ public static string GetMemberPath(this MemberExpression memberExpression) } return path; } + + internal static bool IsMappedAs(MethodInfo methodInfo) + { + return methodInfo.Name == nameof(LinqExtensionMethods.MappedAs) && + methodInfo.DeclaringType == typeof(LinqExtensionMethods); + } } } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 886d4e0e2b1..689457a7403 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -62,6 +62,7 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // TODO: There are a number of types of expressions that we didn't handle here due to time constraints. For example, the ?: operator could be checked easily. private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private readonly Stack _handled = new Stack(); @@ -71,10 +72,11 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // The following is used for member expressions traversal. private int _memberExpressionDepth; - internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } public Expression Transform(Expression expression) @@ -329,7 +331,7 @@ protected override Expression VisitMember(MemberExpression expression) { // Don't add joins for things like a.B == a.C where B and C are entities. // We only need to join B when there's something like a.B.D. - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); if (_memberExpressionDepth > 0 && _joiner.CanAddJoin(expression)) { diff --git a/src/NHibernate/Param/NamedListParameter.cs b/src/NHibernate/Param/NamedListParameter.cs new file mode 100644 index 00000000000..9e02fd442b6 --- /dev/null +++ b/src/NHibernate/Param/NamedListParameter.cs @@ -0,0 +1,13 @@ +using NHibernate.Type; + +namespace NHibernate.Param +{ + internal class NamedListParameter : NamedParameter + { + public NamedListParameter(string name, object value, IType elementType) : base(name, value, elementType) + { + } + + public override bool IsCollection => true; + } +} diff --git a/src/NHibernate/Param/NamedParameter.cs b/src/NHibernate/Param/NamedParameter.cs index b42f69925f0..95f6604ed33 100644 --- a/src/NHibernate/Param/NamedParameter.cs +++ b/src/NHibernate/Param/NamedParameter.cs @@ -15,6 +15,8 @@ public NamedParameter(string name, object value, IType type) public object Value { get; internal set; } public IType Type { get; internal set; } + public virtual bool IsCollection => false; + public bool Equals(NamedParameter other) { if (ReferenceEquals(null, other)) @@ -38,4 +40,4 @@ public override int GetHashCode() return (Name != null ? Name.GetHashCode() : 0); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/ParameterHelper.cs b/src/NHibernate/Util/ParameterHelper.cs new file mode 100644 index 00000000000..d201ebccc3c --- /dev/null +++ b/src/NHibernate/Util/ParameterHelper.cs @@ -0,0 +1,166 @@ +using System; +using System.Collections; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Proxy; +using NHibernate.Type; + +namespace NHibernate.Util +{ + internal static class ParameterHelper + { + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// The output parameter that represents whether the is a collection. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, out bool isCollection) + { + if (param == null) + { + throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value."); + } + + if (param is IEnumerable enumerable && !(param is string)) + { + var firstValue = enumerable.Cast().FirstOrDefault(); + isCollection = true; + return firstValue == null + ? TryGuessType(enumerable.GetCollectionElementType(), sessionFactory) + : TryGuessType(firstValue, sessionFactory, out _); + } + + isCollection = false; + var clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType GuessType(object param, ISessionFactoryImplementor sessionFactory) + { + if (param == null) + { + throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value."); + } + + System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return GuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// The output parameter that represents whether the is a collection. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, out bool isCollection) + { + if (clazz == null) + { + throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + } + + if (clazz.IsCollectionType()) + { + isCollection = true; + return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, out _); + } + + isCollection = false; + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// The output parameter that represents whether the is a collection. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, out bool isCollection) + { + if (clazz == null) + { + throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + } + + if (typeof(IEnumerable).IsAssignableFrom(clazz) && typeof(string) != clazz) + { + isCollection = true; + return GuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory); + } + + isCollection = false; + return GuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + return TryGuessType(clazz, sessionFactory) ?? + throw new HibernateException("Could not determine a type for class: " + clazz.AssemblyQualifiedName); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + } + + var type = TypeFactory.HeuristicType(clazz); + if (type == null || type is SerializableType) + { + if (sessionFactory.TryGetEntityPersister(clazz.FullName) != null) + { + return NHibernateUtil.Entity(clazz); + } + } + + return type; + } + } +} diff --git a/src/NHibernate/Util/ReflectionCache.cs b/src/NHibernate/Util/ReflectionCache.cs index c40a395f98d..47fde15950d 100644 --- a/src/NHibernate/Util/ReflectionCache.cs +++ b/src/NHibernate/Util/ReflectionCache.cs @@ -54,6 +54,11 @@ internal static class EnumerableMethods internal static readonly MethodInfo ToListDefinition = ReflectHelper.FastGetMethodDefinition(Enumerable.ToList, default(IEnumerable)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), default(int)); } internal static class MethodBaseMethods @@ -215,6 +220,11 @@ internal static class QueryableMethods ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); internal static readonly MethodInfo AverageWithSelectorOfNullableDecimalDefinition = ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), default(int)); } internal static class TypeMethods From a458354b1180bf5a82b86514650f4e0e012d890a Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 27 Apr 2020 13:35:12 +0200 Subject: [PATCH 02/11] Fix CodeFactor issue --- src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs index 4e21a078ebe..e16231f7f3c 100644 --- a/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs @@ -228,7 +228,6 @@ public void ConditionalMemberTest() ); } - [Test] public void AssignMemberTest() { From c83cc530590d766b51f0473ba872656855c17182 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 9 May 2020 23:01:59 +0200 Subject: [PATCH 03/11] Add support for dynamic components --- .../Northwind/Entities/DynamicUser.cs | 13 +++ .../Northwind/Entities/Northwind.cs | 5 + .../Northwind/Mappings/DynamicUser.hbm.xml | 30 ++++++ src/NHibernate.Test/Linq/LinqTestCase.cs | 3 +- ...rTests.cs => ParameterTypeLocatorTests.cs} | 48 +++++++-- src/NHibernate/Linq/NhLinqExpression.cs | 20 +--- .../Linq/Visitors/ExpressionKeyVisitor.cs | 14 +++ .../Visitors/ExpressionParameterVisitor.cs | 12 +++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 7 +- ...TypeLocator.cs => ParameterTypeLocator.cs} | 97 +++++++++++-------- src/NHibernate/Linq/Visitors/VisitorUtil.cs | 52 ++++++---- src/NHibernate/Util/ExpressionsHelper.cs | 45 +++++++++ 12 files changed, 256 insertions(+), 90 deletions(-) create mode 100644 src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs create mode 100644 src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml rename src/NHibernate.Test/Linq/{ConstantTypeLocatorTests.cs => ParameterTypeLocatorTests.cs} (88%) rename src/NHibernate/Linq/Visitors/{ConstantTypeLocator.cs => ParameterTypeLocator.cs} (74%) diff --git a/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs new file mode 100644 index 00000000000..974b1ccea7d --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs @@ -0,0 +1,13 @@ +using System.Collections; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class DynamicUser + { + public virtual int Id { get; set; } + + public virtual dynamic Properties { get; set; } + + public virtual IDictionary Settings { get; set; } + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index c4cbda23f26..4551ce0e9d8 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -69,6 +69,11 @@ public IQueryable Users get { return _session.Query(); } } + public IQueryable DynamicUsers + { + get { return _session.Query(); } + } + public IQueryable PatientRecords { get { return _session.Query(); } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml new file mode 100644 index 00000000000..1b6775b29c6 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml @@ -0,0 +1,30 @@ + + + + + select * from Users + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index e047732d7ad..daf14b9cd18 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -34,7 +34,8 @@ protected override string[] Mappings "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", - "Northwind.Mappings.Patient.hbm.xml" + "Northwind.Mappings.Patient.hbm.xml", + "Northwind.Mappings.DynamicUser.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs similarity index 88% rename from src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs rename to src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index e16231f7f3c..e6a46ddf27b 100644 --- a/src/NHibernate.Test/Linq/ConstantTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Globalization; using System.Linq; +using System.Linq.Dynamic.Core; using System.Linq.Expressions; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Engine.Query; @@ -13,7 +14,7 @@ namespace NHibernate.Test.Linq { - public class ConstantTypeLocatorTests : LinqTestCase + public class ParameterTypeLocatorTests : LinqTestCase { [Test] public void AddIntegerTest() @@ -172,8 +173,7 @@ public void ConditionalTest() new Dictionary> { {"2", o => o is EnumStoredAsStringType}, - {"Unspecified", o => o is EnumStoredAsStringType}, - {"null", o => o is PersistentEnumType}, // HasValue + {"Unspecified", o => o is EnumStoredAsStringType} }, db.Users.Where(o => (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium), db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1)) @@ -189,8 +189,7 @@ public void DoubleConditionalTest() {"0", o => o is PersistentEnumType}, {"2", o => o is EnumStoredAsStringType}, {"Small", o => o is EnumStoredAsStringType}, - {"Unspecified", o => o is EnumStoredAsStringType}, - {"null", o => o is PersistentEnumType}, // HasValue + {"Unspecified", o => o is EnumStoredAsStringType} }, db.Users.Where(o => (o.Enum2 != EnumStoredAsInt32.Unspecified ? (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) @@ -228,6 +227,36 @@ public void ConditionalMemberTest() ); } + [Test] + public void DynamicMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, + db.DynamicUsers.Where("Properties.Name == @0", "test"), + db.DynamicUsers.Where("@0 == Properties.Name", "test") + ); + } + + [Test] + public void DynamicDictionaryMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, +#pragma warning disable CS0252 + db.DynamicUsers.Where(o => o.Settings["Property1"] == "test"), +#pragma warning restore CS0252 +#pragma warning disable CS0253 + db.DynamicUsers.Where(o => "test" == o.Settings["Property1"]) +#pragma warning restore CS0253 + ); + } + [Test] public void AssignMemberTest() { @@ -373,11 +402,12 @@ private void AssertResult( System.Type targetType) { var result = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(queryMode, Sfi)); + var parameters = ExpressionParameterVisitor.Visit(result); expression = result.Expression; var queryModel = NhRelinqQueryParser.Parse(expression); - var types = ConstantTypeLocator.GetTypes(queryModel, targetType, Sfi); - Assert.That(types.Count, Is.EqualTo(expectedResults.Count), "Incorrect number of constants"); - foreach (var pair in types) + ParameterTypeLocator.SetParameterTypes(parameters, queryModel, targetType, Sfi); + Assert.That(parameters.Count, Is.EqualTo(expectedResults.Count), "Incorrect number of parameters"); + foreach (var pair in parameters) { var origCulture = CultureInfo.CurrentCulture; try @@ -385,7 +415,7 @@ private void AssertResult( CultureInfo.CurrentCulture = CultureInfo.InvariantCulture; var expressionText = pair.Key.ToString(); Assert.That(expectedResults.ContainsKey(expressionText), Is.True, $"{expressionText} constant is not expected"); - Assert.That(expectedResults[expressionText](pair.Value), Is.True, $"Invalid type, actual type: {pair.Value?.Name ?? "null"}"); + Assert.That(expectedResults[expressionText](pair.Value.Type), Is.True, $"Invalid type, actual type: {pair.Value?.Name ?? "null"}"); } finally { diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index 50b41b4f927..cfb0b1d54b0 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -91,7 +91,7 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter var requiredHqlParameters = new List(); var queryModel = NhRelinqQueryParser.Parse(_expression); queryModel.TransformExpressions(TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers); - SetParameterTypes(sessionFactory, queryModel); + ParameterTypeLocator.SetParameterTypes(_constantToParameterMap, queryModel, TargetType, sessionFactory, true); var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, new QuerySourceNamer(), TargetType, QueryMode); @@ -122,24 +122,6 @@ internal void CopyExpressionTranslation(NhLinqExpression other) Type = other.Type; } - private void SetParameterTypes( - ISessionFactoryImplementor sessionFactory, - QueryModel queryModel) - { - if (_constantToParameterMap.Count == 0) - { - return; - } - - foreach (var pair in ConstantTypeLocator.GetTypes(queryModel, TargetType, sessionFactory, true)) - { - if (_constantToParameterMap.TryGetValue(pair.Key, out var parameter)) - { - parameter.Type = pair.Value; - } - } - } - private static IASTNode DuplicateTree(IASTNode ast) { var thisNode = ast.DupNode(); diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index 28261d2088c..bee3cb2e960 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -10,6 +10,7 @@ using NHibernate.Engine; using NHibernate.Param; using NHibernate.Type; +using NHibernate.Util; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -206,6 +207,19 @@ protected override Expression VisitMember(MemberExpression expression) return expression; } + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var memberBinder)) + { + Visit(expression.Arguments[1]); + _string.Append("."); + _string.Append(memberBinder.Name); + return expression; + } + + return base.VisitInvocation(expression); + } + protected override Expression VisitMethodCall(MethodCallExpression expression) { Visit(expression.Object); diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 51607c40c1a..fc0a2f8ad86 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -106,6 +106,18 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return base.VisitMethodCall(expression); } + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out _)) + { + // Avoid adding System.Runtime.CompilerServices.CallSite instance as a parameter + base.Visit(expression.Arguments[1]); + return expression; + } + + return base.VisitInvocation(expression); + } + protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index cd9cd49eadb..2869388ca7f 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -226,12 +226,7 @@ private HqlTreeNode VisitNhNominated(NhNominatedExpression nhNominatedExpression private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { - //This is an ugly workaround for dynamic expressions. - //Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression - if (expression.Arguments.Count == 2 && - expression.Arguments[0] is ConstantExpression constant && - constant.Value is CallSite site && - site.Binder is GetMemberBinder binder) + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var binder)) { return _hqlTreeBuilder.Dot( VisitExpression(expression.Arguments[1]).AsExpression(), diff --git a/src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs similarity index 74% rename from src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs rename to src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 534f6623e78..84a901365e6 100644 --- a/src/NHibernate/Linq/Visitors/ConstantTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using System.Linq.Expressions; using NHibernate.Engine; +using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; @@ -10,9 +11,9 @@ namespace NHibernate.Linq.Visitors { /// - /// Locates actual type based on its usage. + /// Locates parameter actual type based on its usage. /// - public static class ConstantTypeLocator + public static class ParameterTypeLocator { /// /// List of for which the should be related to the other side @@ -43,38 +44,49 @@ public static class ConstantTypeLocator }; /// - /// Calculate constant expressions types inside the query model. + /// Set query parameter types based on the given query model. /// + /// The query parameters. /// The query model. /// The target entity type. /// The session factory. - /// - public static Dictionary GetTypes( + public static void SetParameterTypes( + IDictionary parameters, QueryModel queryModel, System.Type targetType, ISessionFactoryImplementor sessionFactory) { - return GetTypes(queryModel, targetType, sessionFactory, false); + SetParameterTypes(parameters, queryModel, targetType, sessionFactory, false); } - internal static Dictionary GetTypes( + internal static void SetParameterTypes( + IDictionary parameters, QueryModel queryModel, System.Type targetType, ISessionFactoryImplementor sessionFactory, bool removeMappedAsCalls) { - var types = new Dictionary(); - var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, sessionFactory); + if (parameters.Count == 0) + { + return; + } + + var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); queryModel.TransformExpressions(visitor.Visit); foreach (var pair in visitor.ConstantExpressions) { var type = pair.Value; var constantExpression = pair.Key; + if (!parameters.TryGetValue(constantExpression, out var namedParameter)) + { + continue; + } + if (type != null) { // MappedAs was used - types.Add(constantExpression, type); + namedParameter.Type = type; continue; } @@ -107,30 +119,31 @@ internal static Dictionary GetTypes( : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, out _); } - types.Add(constantExpression, type); + namedParameter.Type = type; } - - return types; } private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor { private readonly bool _removeMappedAsCalls; private readonly System.Type _targetType; + private readonly IDictionary _parameters; private readonly ISessionFactoryImplementor _sessionFactory; public readonly Dictionary ConstantExpressions = new Dictionary(); - public readonly Dictionary> RelatedExpressions = - new Dictionary>(); + public readonly Dictionary> RelatedExpressions = + new Dictionary>(); public ConstantTypeLocatorVisitor( bool removeMappedAsCalls, System.Type targetType, + IDictionary parameters, ISessionFactoryImplementor sessionFactory) { _removeMappedAsCalls = removeMappedAsCalls; _targetType = targetType; _sessionFactory = sessionFactory; + _parameters = parameters; } protected override Expression VisitBinary(BinaryExpression node) @@ -149,8 +162,8 @@ protected override Expression VisitBinary(BinaryExpression node) } else { - AddRelatedMemberExpression(node, left, right); - AddRelatedMemberExpression(node, right, left); + AddRelatedExpression(node, left, right); + AddRelatedExpression(node, right, left); } return node; @@ -161,18 +174,12 @@ protected override Expression VisitConditional(ConditionalExpression node) node = (ConditionalExpression) base.VisitConditional(node); var ifTrue = Unwrap(node.IfTrue); var ifFalse = Unwrap(node.IfFalse); - AddRelatedMemberExpression(node, ifTrue, ifFalse); - AddRelatedMemberExpression(node, ifFalse, ifTrue); + AddRelatedExpression(node, ifTrue, ifFalse); + AddRelatedExpression(node, ifFalse, ifTrue); return node; } - protected override MemberAssignment VisitMemberAssignment(MemberAssignment node) - { - node = base.VisitMemberAssignment(node); - return node; - } - protected override Expression VisitMethodCall(MethodCallExpression node) { if (VisitorUtil.IsMappedAs(node.Method)) @@ -201,12 +208,12 @@ protected override Expression VisitMethodCall(MethodCallExpression node) protected override Expression VisitConstant(ConstantExpression node) { - if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node)) + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.ContainsKey(node)) { return node; } - RelatedExpressions.Add(node, new HashSet()); + RelatedExpressions.Add(node, new HashSet()); ConstantExpressions.Add(node, null); return node; } @@ -241,42 +248,56 @@ private void VisitAssign(Expression leftNode, Expression rightNode) ConstantExpressions[constantExpression] = persister.EntityMetamodel.GetPropertyType(parameterExpression.Name); } - private void AddRelatedMemberExpression(Expression node, Expression left, Expression right) + private void AddRelatedExpression(Expression node, Expression left, Expression right) { - HashSet set; - if (left is MemberExpression leftMemberExpression) + if (left.NodeType == ExpressionType.MemberAccess || IsDynamicMember(left)) { - AddMemberExpression(right, leftMemberExpression); + AddRelatedExpression(right, left); if (NonVoidOperators.Contains(node.NodeType)) { - AddMemberExpression(node, leftMemberExpression); + AddRelatedExpression(node, left); } } // Copy all found MemberExpressions to the other side // (e.g. (o.Prop ?? constant1) == constant2 -> copy o.Prop to constant2) - if (RelatedExpressions.TryGetValue(left, out set)) + if (RelatedExpressions.TryGetValue(left, out var set)) { foreach (var nestedMemberExpression in set) { - AddMemberExpression(right, nestedMemberExpression); + AddRelatedExpression(right, nestedMemberExpression); if (NonVoidOperators.Contains(node.NodeType)) { - AddMemberExpression(node, nestedMemberExpression); + AddRelatedExpression(node, nestedMemberExpression); } } } } - private void AddMemberExpression(Expression expression, MemberExpression memberExpression) + private void AddRelatedExpression(Expression expression, Expression relatedExpression) { if (!RelatedExpressions.TryGetValue(expression, out var set)) { - set = new HashSet(); + set = new HashSet(); RelatedExpressions.Add(expression, set); } - set.Add(memberExpression); + set.Add(relatedExpression); + } + + private bool IsDynamicMember(Expression expression) + { + switch (expression) + { + case InvocationExpression invocationExpression: + // session.Query().Where("Properties.Name == @0", "First Product") + return ExpressionsHelper.TryGetDynamicMemberBinder(invocationExpression, out _); + case MethodCallExpression methodCallExpression: + // session.Query() where p.Properties["Name"] == "First Product" select p + return VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(methodCallExpression, out _); + default: + return false; + } } private static Expression Unwrap(Expression expression) diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 885d2c66a4e..40dbaeb4fb7 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -13,25 +13,12 @@ public static class VisitorUtil { public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Expression targetObject, IEnumerable arguments, ISessionFactory sessionFactory, out string memberName) { - memberName = null; - - // A dynamic component must be an IDictionary with a string key. - - if (method.Name != "get_Item" || !typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type)) - return false; - - var key = arguments.First() as ConstantExpression; - if (key == null || key.Type != typeof(string)) - return false; - - // The potential member name - memberName = (string)key.Value; - - // Need the owning member (the dictionary). - var member = targetObject as MemberExpression; - if (member == null) + if (!TryGetPotentialDynamicComponentDictionaryMember(method, targetObject, arguments, out memberName)) + { return false; + } + var member = (MemberExpression) targetObject; var memberPath = member.Member.Name; var metaData = sessionFactory.GetClassMetadata(member.Expression.Type); @@ -132,6 +119,37 @@ public static string GetMemberPath(this MemberExpression memberExpression) return path; } + internal static bool TryGetPotentialDynamicComponentDictionaryMember(MethodCallExpression expression, out string memberName) + { + return TryGetPotentialDynamicComponentDictionaryMember( + expression.Method, + expression.Object, + expression.Arguments, + out memberName); + } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember( + MethodInfo method, + Expression targetObject, + IEnumerable arguments, + out string memberName) + { + memberName = null; + // A dynamic component must be an IDictionary with a string key. + if (method.Name != "get_Item" || + targetObject.NodeType != ExpressionType.MemberAccess || // Need the owning member (the dictionary). + !(arguments.First() is ConstantExpression key) || + key.Type != typeof(string) || + (!typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type))) + { + return false; + } + + // The potential member name + memberName = (string) key.Value; + return true; + } + internal static bool IsMappedAs(MethodInfo methodInfo) { return methodInfo.Name == nameof(LinqExtensionMethods.MappedAs) && diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 08a60aeeb66..e9aa1c23735 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -30,6 +31,29 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } + /// + /// Try to retrieve from a reduced expression. + /// + /// The reduced dynamic expression. + /// The out binder parameter. + /// Whether the binder was found. + internal static bool TryGetDynamicMemberBinder(InvocationExpression expression, out GetMemberBinder memberBinder) + { + // This is an ugly workaround for dynamic expressions. + // Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression + if (expression.Arguments.Count == 2 && + expression.Arguments[0] is ConstantExpression constant && + constant.Value is CallSite site && + site.Binder is GetMemberBinder binder) + { + memberBinder = binder; + return true; + } + + memberBinder = null; + return false; + } + /// /// Check whether the given expression represent a variable. /// @@ -635,6 +659,19 @@ protected override Expression VisitMember(MemberExpression node) return base.Visit(node.Expression); } + protected override Expression VisitInvocation(InvocationExpression node) + { + if (TryGetDynamicMemberBinder(node, out var binder)) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[1]); + } + + return base.Visit(node); + } + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) { if (node.ReferencedQuerySource is IFromClause fromClause) @@ -721,6 +758,14 @@ protected override Expression VisitMethodCall(MethodCallExpression node) ); } + if (VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(node, out var memberName)) + { + _memberPaths.Push(new MemberMetadata(memberName, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Object); + } + return Visit(node); } From e97798231bf8f41eedafa346d2ff5e1d166385cd Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 10 May 2020 21:17:05 +0200 Subject: [PATCH 04/11] Fix broken test --- .../Linq/ParameterTypeLocatorTests.cs | 2 +- src/NHibernate/Linq/NhLinqExpression.cs | 1 - .../Linq/Visitors/ExpressionKeyVisitor.cs | 7 +++--- .../Visitors/ExpressionParameterVisitor.cs | 2 ++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 3 ++- .../Linq/Visitors/ParameterTypeLocator.cs | 5 ++++ src/NHibernate/Util/ExpressionsHelper.cs | 24 +++++++++++++++++-- 7 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index e6a46ddf27b..1c0727f3be7 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -415,7 +415,7 @@ private void AssertResult( CultureInfo.CurrentCulture = CultureInfo.InvariantCulture; var expressionText = pair.Key.ToString(); Assert.That(expectedResults.ContainsKey(expressionText), Is.True, $"{expressionText} constant is not expected"); - Assert.That(expectedResults[expressionText](pair.Value.Type), Is.True, $"Invalid type, actual type: {pair.Value?.Name ?? "null"}"); + Assert.That(expectedResults[expressionText](pair.Value.Type), Is.True, $"Invalid type, actual type: {pair.Value?.Type?.Name ?? "null"}"); } finally { diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index cfb0b1d54b0..671d949fc02 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -8,7 +8,6 @@ using NHibernate.Linq.Visitors; using NHibernate.Param; using NHibernate.Type; -using Remotion.Linq; namespace NHibernate.Linq { diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index bee3cb2e960..fdce29edd9a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -207,18 +207,19 @@ protected override Expression VisitMember(MemberExpression expression) return expression; } +#if NETCOREAPP2_0 protected override Expression VisitInvocation(InvocationExpression expression) { if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var memberBinder)) { Visit(expression.Arguments[1]); - _string.Append("."); - _string.Append(memberBinder.Name); + FormatBinder(memberBinder); return expression; } return base.VisitInvocation(expression); } +#endif protected override Expression VisitMethodCall(MethodCallExpression expression) { @@ -279,8 +280,8 @@ protected override Expression VisitQuerySourceReference(Remotion.Linq.Clauses.Ex protected override Expression VisitDynamic(DynamicExpression expression) { - FormatBinder(expression.Binder); Visit(expression.Arguments, AppendCommas); + FormatBinder(expression.Binder); return expression; } diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index fc0a2f8ad86..b5f99196ef0 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -106,6 +106,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return base.VisitMethodCall(expression); } +#if NETCOREAPP2_0 protected override Expression VisitInvocation(InvocationExpression expression) { if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out _)) @@ -117,6 +118,7 @@ protected override Expression VisitInvocation(InvocationExpression expression) return base.VisitInvocation(expression); } +#endif protected override Expression VisitConstant(ConstantExpression expression) { diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 2869388ca7f..a16968cccb2 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -226,13 +226,14 @@ private HqlTreeNode VisitNhNominated(NhNominatedExpression nhNominatedExpression private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { +#if NETCOREAPP2_0 if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var binder)) { return _hqlTreeBuilder.Dot( VisitExpression(expression.Arguments[1]).AsExpression(), _hqlTreeBuilder.Ident(binder.Name)); } - +#endif return VisitExpression(expression.Expression); } diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 84a901365e6..34e492db45b 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Dynamic; using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Param; @@ -289,9 +290,13 @@ private bool IsDynamicMember(Expression expression) { switch (expression) { +#if NETCOREAPP2_0 case InvocationExpression invocationExpression: // session.Query().Where("Properties.Name == @0", "First Product") return ExpressionsHelper.TryGetDynamicMemberBinder(invocationExpression, out _); +#endif + case DynamicExpression dynamicExpression: + return dynamicExpression.Binder is GetMemberBinder; case MethodCallExpression methodCallExpression: // session.Query() where p.Properties["Name"] == "First Product" select p return VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(methodCallExpression, out _); diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index e9aa1c23735..eebee36c8dd 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -16,6 +16,7 @@ using NHibernate.Type; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; namespace NHibernate.Util { @@ -31,6 +32,7 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } +#if NETCOREAPP2_0 /// /// Try to retrieve from a reduced expression. /// @@ -39,7 +41,9 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi /// Whether the binder was found. internal static bool TryGetDynamicMemberBinder(InvocationExpression expression, out GetMemberBinder memberBinder) { - // This is an ugly workaround for dynamic expressions. + // This is an ugly workaround for dynamic expressions in .NET Core. In .NET Core a dynamic expression is reduced + // when first visited by a expression visitor that is not a DynamicExpressionVisitor, where in .NET Framework it is never reduced. + // As RelinqExpressionVisitor does not extend DynamicExpressionVisitor, we will always have a reduced dynamic expression in .NET Core. // Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression if (expression.Arguments.Count == 2 && expression.Arguments[0] is ConstantExpression constant && @@ -53,6 +57,7 @@ constant.Value is CallSite site && memberBinder = null; return false; } +#endif /// /// Check whether the given expression represent a variable. @@ -659,6 +664,7 @@ protected override Expression VisitMember(MemberExpression node) return base.Visit(node.Expression); } +#if NETCOREAPP2_0 protected override Expression VisitInvocation(InvocationExpression node) { if (TryGetDynamicMemberBinder(node, out var binder)) @@ -669,7 +675,21 @@ protected override Expression VisitInvocation(InvocationExpression node) return base.Visit(node.Arguments[1]); } - return base.Visit(node); + return base.VisitInvocation(node); + } +#endif + + protected override Expression VisitDynamic(DynamicExpression node) + { + if (node.Binder is GetMemberBinder binder) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[0]); + } + + return Visit(node); } protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) From 35c511e804553f7f77e88d41b57f69b7e7eb3f3b Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 10 May 2020 21:54:38 +0200 Subject: [PATCH 05/11] Code review changes --- src/NHibernate/Driver/SqlClientDriver.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Driver/SqlClientDriver.cs b/src/NHibernate/Driver/SqlClientDriver.cs index ec876494370..1002005395d 100644 --- a/src/NHibernate/Driver/SqlClientDriver.cs +++ b/src/NHibernate/Driver/SqlClientDriver.cs @@ -295,7 +295,7 @@ protected static bool IsBlob(DbParameter dbParam, SqlType sqlType) /// True, if the parameter should be interpreted as a character, otherwise False protected static bool IsChar(DbParameter dbParam, SqlType sqlType) { - return (DbType.StringFixedLength == dbParam.DbType || DbType.StringFixedLength == dbParam.DbType) && + return (DbType.StringFixedLength == dbParam.DbType || DbType.AnsiStringFixedLength == dbParam.DbType) && sqlType.LengthDefined && sqlType.Length == 1; } From e7101722c33d367a0ada9fe13dffd9f5e55d6d92 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 11 May 2020 03:21:56 +0200 Subject: [PATCH 06/11] Fix IEnumerable parameters outside Contains --- .../Northwind/Entities/DynamicUser.cs | 7 ++- .../Async/Linq/ParameterTests.cs | 28 ++++++++++ src/NHibernate.Test/Linq/ParameterTests.cs | 28 ++++++++++ .../Linq/ParameterTypeLocatorTests.cs | 17 +++++++ .../Visitors/ExpressionParameterVisitor.cs | 27 ++++++---- .../Linq/Visitors/ParameterTypeLocator.cs | 8 +-- src/NHibernate/Util/ParameterHelper.cs | 51 +++++-------------- 7 files changed, 112 insertions(+), 54 deletions(-) diff --git a/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs index 974b1ccea7d..c8c9fb3cf03 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs @@ -2,12 +2,17 @@ namespace NHibernate.DomainModel.Northwind.Entities { - public class DynamicUser + public class DynamicUser : IEnumerable { public virtual int Id { get; set; } public virtual dynamic Properties { get; set; } public virtual IDictionary Settings { get; set; } + + public virtual IEnumerator GetEnumerator() + { + throw new System.NotImplementedException(); + } } } diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 4fbebe3e78b..0956fdfe92b 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -88,6 +88,34 @@ public async Task UsingTwoEntityParametersAsync() 2)); } + [Test] + public async Task UsingEntityEnumerableParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = await (db.DynamicUsers.FirstAsync()); + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1)); + } + + [Test] + public async Task UsingEntityEnumerableListParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {await (db.DynamicUsers.FirstAsync())}; + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1)); + } + [Test] public async Task UsingValueTypeParameterTwiceAsync() { diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 920fa565129..cab27fe9dd5 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -76,6 +76,34 @@ public void UsingTwoEntityParameters() 2); } + [Test] + public void UsingEntityEnumerableParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = db.DynamicUsers.First(); + AssertTotalParameters( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1); + } + + [Test] + public void UsingEntityEnumerableListParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {db.DynamicUsers.First()}; + AssertTotalParameters( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1); + } + [Test] public void UsingValueTypeParameterTwice() { diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs index 1c0727f3be7..2cb87bc50b2 100644 --- a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -97,6 +97,23 @@ public void EqualStringTest() ); } + [Test] + public void EqualEntityTest() + { + var order = new Order(); + AssertResults( + new Dictionary> + { + { + $"value({typeof(Order).FullName})", + o => o is ManyToOneType manyToOne && manyToOne.Name == typeof(Order).FullName + } + }, + db.Orders.Where(o => o == order), + db.Orders.Where(o => order == o) + ); + } + [Test] public void DoubleEqualTest() { diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index b5f99196ef0..3a0cfc0174a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -5,6 +5,7 @@ using System.Linq.Expressions; using System.Reflection; using NHibernate.Engine; +using NHibernate.Linq.Functions; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; @@ -19,8 +20,10 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); private readonly Dictionary _variableParameters = new Dictionary(); + private readonly HashSet _collectionParameters = new HashSet(); private readonly IDictionary _queryVariables; private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; private static readonly ISet PagingMethods = new HashSet { @@ -41,6 +44,7 @@ public ExpressionParameterVisitor(PreTransformationResult preTransformationResul { _sessionFactory = preTransformationResult.SessionFactory; _queryVariables = preTransformationResult.QueryVariables; + _functionRegistry = _sessionFactory.Settings.LinqToHqlGeneratorsRegistry; } // Since v5.3 @@ -98,6 +102,17 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return Expression.Call(null, expression.Method, query, arg); } + if (_functionRegistry != null && + _functionRegistry.TryGetGenerator(method, out var generator) && + generator is CollectionContainsGenerator) + { + var argument = method.IsStatic ? expression.Arguments[0] : expression.Object; + if (argument is ConstantExpression constantExpression) + { + _collectionParameters.Add(constantExpression); + } + } + if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory)) { return expression; @@ -172,7 +187,7 @@ protected override Expression VisitConstant(ConstantExpression expression) private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type) { var parameterName = "p" + (_parameters.Count + 1); - return IsCollectionType(expression) + return _collectionParameters.Contains(expression) ? new NamedListParameter(parameterName, value, type) : new NamedParameter(parameterName, value, type); } @@ -181,15 +196,5 @@ private static bool IsNullObject(ConstantExpression expression) { return expression.Type == typeof(Object) && expression.Value == null; } - - private static bool IsCollectionType(ConstantExpression expression) - { - if (expression.Value != null) - { - return expression.Value is IEnumerable && !(expression.Value is string); - } - - return expression.Type.IsCollectionType(); - } } } diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 34e492db45b..34326640169 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -116,8 +116,8 @@ internal static void SetParameterTypes( if (type == null) { type = constantExpression.Value != null - ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, out _) - : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, out _); + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection); } namedParameter.Type = type; @@ -251,7 +251,9 @@ private void VisitAssign(Expression leftNode, Expression rightNode) private void AddRelatedExpression(Expression node, Expression left, Expression right) { - if (left.NodeType == ExpressionType.MemberAccess || IsDynamicMember(left)) + if (left.NodeType == ExpressionType.MemberAccess || + IsDynamicMember(left) || + left is QuerySourceReferenceExpression) { AddRelatedExpression(right, left); if (NonVoidOperators.Contains(node.NodeType)) diff --git a/src/NHibernate/Util/ParameterHelper.cs b/src/NHibernate/Util/ParameterHelper.cs index d201ebccc3c..d0b6bd14625 100644 --- a/src/NHibernate/Util/ParameterHelper.cs +++ b/src/NHibernate/Util/ParameterHelper.cs @@ -14,29 +14,27 @@ internal static class ParameterHelper /// /// The object to guess the of. /// The session factory to search for entity persister. - /// The output parameter that represents whether the is a collection. + /// Whether is a collection. /// An for the object. /// /// Thrown when the param is null because the /// can't be guess from a null value. /// - public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, out bool isCollection) + public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, bool isCollection) { if (param == null) { - throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value."); + return null; } - if (param is IEnumerable enumerable && !(param is string)) + if (param is IEnumerable enumerable && isCollection) { var firstValue = enumerable.Cast().FirstOrDefault(); - isCollection = true; return firstValue == null ? TryGuessType(enumerable.GetCollectionElementType(), sessionFactory) - : TryGuessType(firstValue, sessionFactory, out _); + : TryGuessType(firstValue, sessionFactory, false); } - isCollection = false; var clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); return TryGuessType(clazz, sessionFactory); } @@ -67,26 +65,24 @@ public static IType GuessType(object param, ISessionFactoryImplementor sessionFa /// /// The to guess the of. /// The session factory to search for entity persister. - /// The output parameter that represents whether the is a collection. + /// Whether is a collection. /// An for the . /// /// Thrown when the clazz is null because the /// can't be guess from a null type. /// - public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, out bool isCollection) + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, bool isCollection) { if (clazz == null) { - throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + return null; } - if (clazz.IsCollectionType()) + if (isCollection) { - isCollection = true; - return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, out _); + return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, false); } - isCollection = false; return TryGuessType(clazz, sessionFactory); } @@ -95,41 +91,18 @@ public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor s /// /// The to guess the of. /// The session factory to search for entity persister. - /// The output parameter that represents whether the is a collection. /// An for the . /// /// Thrown when the clazz is null because the /// can't be guess from a null type. /// - public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, out bool isCollection) + public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) { if (clazz == null) { throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); } - if (typeof(IEnumerable).IsAssignableFrom(clazz) && typeof(string) != clazz) - { - isCollection = true; - return GuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory); - } - - isCollection = false; - return GuessType(clazz, sessionFactory); - } - - /// - /// Guesses the from the . - /// - /// The to guess the of. - /// The session factory to search for entity persister. - /// An for the . - /// - /// Thrown when the clazz is null because the - /// can't be guess from a null type. - /// - public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) - { return TryGuessType(clazz, sessionFactory) ?? throw new HibernateException("Could not determine a type for class: " + clazz.AssemblyQualifiedName); } @@ -148,7 +121,7 @@ public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor s { if (clazz == null) { - throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + return null; } var type = TypeFactory.HeuristicType(clazz); From 87e43e3f57c3c82d12aac4cef4252f30f1c3cbf6 Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 19 May 2020 21:08:39 +0200 Subject: [PATCH 07/11] Code review changes --- .../Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs | 16 +--------------- .../Linq/Visitors/ExpressionParameterVisitor.cs | 9 +++++---- src/NHibernate/Param/NamedListParameter.cs | 13 ------------- src/NHibernate/Param/NamedParameter.cs | 8 +++++++- 4 files changed, 13 insertions(+), 33 deletions(-) delete mode 100644 src/NHibernate/Param/NamedListParameter.cs diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs index 8b909121987..2ff3a10f790 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs @@ -69,21 +69,7 @@ private static void Check(IASTNode check, IASTNode first, IASTNode second) return; } - IType expectedType = null; - if (first is SqlNode firstNode) - { - expectedType = firstNode.DataType; - } - - if (expectedType == null) - { - if (second is SqlNode secondNode) - { - expectedType = secondNode.DataType; - } - } - - expectedTypeAwareNode.ExpectedType = expectedType; + expectedTypeAwareNode.ExpectedType = (first as SqlNode)?.DataType ?? (second as SqlNode)?.DataType; } } } diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 3a0cfc0174a..e23f2388a01 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -186,10 +186,11 @@ protected override Expression VisitConstant(ConstantExpression expression) private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type) { - var parameterName = "p" + (_parameters.Count + 1); - return _collectionParameters.Contains(expression) - ? new NamedListParameter(parameterName, value, type) - : new NamedParameter(parameterName, value, type); + return new NamedParameter( + "p" + (_parameters.Count + 1), + value, + type, + _collectionParameters.Contains(expression)); } private static bool IsNullObject(ConstantExpression expression) diff --git a/src/NHibernate/Param/NamedListParameter.cs b/src/NHibernate/Param/NamedListParameter.cs deleted file mode 100644 index 9e02fd442b6..00000000000 --- a/src/NHibernate/Param/NamedListParameter.cs +++ /dev/null @@ -1,13 +0,0 @@ -using NHibernate.Type; - -namespace NHibernate.Param -{ - internal class NamedListParameter : NamedParameter - { - public NamedListParameter(string name, object value, IType elementType) : base(name, value, elementType) - { - } - - public override bool IsCollection => true; - } -} diff --git a/src/NHibernate/Param/NamedParameter.cs b/src/NHibernate/Param/NamedParameter.cs index 95f6604ed33..a9a9b67de2b 100644 --- a/src/NHibernate/Param/NamedParameter.cs +++ b/src/NHibernate/Param/NamedParameter.cs @@ -5,17 +5,23 @@ namespace NHibernate.Param public class NamedParameter { public NamedParameter(string name, object value, IType type) + : this(name, value, type, false) + { + } + + internal NamedParameter(string name, object value, IType type, bool isCollection) { Name = name; Value = value; Type = type; + IsCollection = isCollection; } public string Name { get; private set; } public object Value { get; internal set; } public IType Type { get; internal set; } - public virtual bool IsCollection => false; + public virtual bool IsCollection { get; } public bool Equals(NamedParameter other) { From 117367b88341d8e5fd6c062fa81037109fda17fb Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 19 May 2020 23:20:52 +0200 Subject: [PATCH 08/11] Optimize Contains method check --- src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index e23f2388a01..3687054bffe 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -103,6 +103,7 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) } if (_functionRegistry != null && + method.Name == nameof(Queryable.Contains) && _functionRegistry.TryGetGenerator(method, out var generator) && generator is CollectionContainsGenerator) { From 33d88565dc3b700652d299fb47bdc1bb83bb163d Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 19 May 2020 23:23:10 +0200 Subject: [PATCH 09/11] fix formatting --- src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 3687054bffe..d7224d6260a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -102,8 +102,8 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return Expression.Call(null, expression.Method, query, arg); } - if (_functionRegistry != null && - method.Name == nameof(Queryable.Contains) && + if (_functionRegistry != null && + method.Name == nameof(Queryable.Contains) && _functionRegistry.TryGetGenerator(method, out var generator) && generator is CollectionContainsGenerator) { From 1e63f27514d15f55e94203c954acd5efb9da2651 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 25 May 2020 13:27:30 +0200 Subject: [PATCH 10/11] Update src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs Co-authored-by: Alexander Zaytsev --- src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index d7224d6260a..fa21d745d31 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -25,7 +25,7 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor private readonly ISessionFactoryImplementor _sessionFactory; private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - private static readonly ISet PagingMethods = new HashSet + private static readonly HashSet PagingMethods = new HashSet { ReflectionCache.EnumerableMethods.SkipDefinition, ReflectionCache.EnumerableMethods.TakeDefinition, From 98137f1b55936395f8067850fd24629fa72f2a3c Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 25 May 2020 13:47:01 +0200 Subject: [PATCH 11/11] Generalize detection of collection parameters --- .../Functions/BaseHqlGeneratorForMethod.cs | 7 ++++++ .../Functions/GetValueOrDefaultGenerator.cs | 7 ++++++ .../Linq/Functions/IHqlGeneratorForMethod.cs | 23 +++++++++++++++++++ .../Linq/Functions/QueryableGenerator.cs | 9 ++++++++ .../Linq/Functions/StringGenerator.cs | 7 ++++++ .../Visitors/ExpressionParameterVisitor.cs | 13 ++++------- 6 files changed, 57 insertions(+), 9 deletions(-) diff --git a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs index bf97cc54cb7..aa0f1ab88f8 100644 --- a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs @@ -19,6 +19,13 @@ public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGe public virtual bool AllowsNullableReturnType(MethodInfo method) => true; + /// + public virtual bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } + private protected static void LogIgnoredParameter(MethodInfo method, string paramType) { if (Log.IsWarnEnabled()) diff --git a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs index 33cb12c2c6c..cc0fa7202b9 100644 --- a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs +++ b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs @@ -42,5 +42,12 @@ private static HqlExpression GetRhs(MethodInfo method, ReadOnlyCollection !method.ReturnType.IsValueType; + + /// + public bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } } diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index 73ad8b3d9e4..3ef7583ee6b 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -18,6 +18,14 @@ public interface IHqlGeneratorForMethod internal interface IHqlGeneratorForMethodExtended { bool AllowsNullableReturnType(MethodInfo method); + + /// + /// Try getting a collection parameter from . + /// + /// The method call expression. + /// Output parameter for the retrieved collection parameter. + /// Whether collection parameter was retrieved. + bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter); } internal static class HqlGeneratorForMethodExtensions @@ -33,6 +41,21 @@ public static bool AllowsNullableReturnType(this IHqlGeneratorForMethod generato return true; } + // 6.0 TODO: Remove + public static bool TryGetCollectionParameters( + this IHqlGeneratorForMethod generator, + MethodCallExpression expression, + out ConstantExpression collectionParameter) + { + if (generator is IHqlGeneratorForMethodExtended extendedGenerator) + { + return extendedGenerator.TryGetCollectionParameter(expression, out collectionParameter); + } + + collectionParameter = null; + return false; + } + // 6.0 TODO: merge into IHqlGeneratorForMethod /// /// Should pre-evaluation be allowed for this method? diff --git a/src/NHibernate/Linq/Functions/QueryableGenerator.cs b/src/NHibernate/Linq/Functions/QueryableGenerator.cs index f007fa22592..eaf44c64f4a 100644 --- a/src/NHibernate/Linq/Functions/QueryableGenerator.cs +++ b/src/NHibernate/Linq/Functions/QueryableGenerator.cs @@ -155,6 +155,15 @@ public CollectionContainsGenerator() public override bool AllowsNullableReturnType(MethodInfo method) => false; + /// + public override bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + var argument = expression.Method.IsStatic ? expression.Arguments[0] : expression.Object; + collectionParameter = argument as ConstantExpression; + + return collectionParameter != null; + } + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // TODO - alias generator diff --git a/src/NHibernate/Linq/Functions/StringGenerator.cs b/src/NHibernate/Linq/Functions/StringGenerator.cs index 5e358f22194..3bb0523a57c 100644 --- a/src/NHibernate/Linq/Functions/StringGenerator.cs +++ b/src/NHibernate/Linq/Functions/StringGenerator.cs @@ -59,6 +59,13 @@ public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) } public bool AllowsNullableReturnType(MethodInfo method) => false; + + /// + public bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } public class LengthGenerator : BaseHqlGeneratorForProperty diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index fa21d745d31..f6a9e5de43f 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -102,16 +102,11 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return Expression.Call(null, expression.Method, query, arg); } - if (_functionRegistry != null && - method.Name == nameof(Queryable.Contains) && - _functionRegistry.TryGetGenerator(method, out var generator) && - generator is CollectionContainsGenerator) + if (_functionRegistry != null && + _functionRegistry.TryGetGenerator(method, out var generator) && + generator.TryGetCollectionParameters(expression, out var collectionParameter)) { - var argument = method.IsStatic ? expression.Arguments[0] : expression.Object; - if (argument is ConstantExpression constantExpression) - { - _collectionParameters.Add(constantExpression); - } + _collectionParameters.Add(collectionParameter); } if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory))