diff --git a/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs new file mode 100644 index 00000000000..c8c9fb3cf03 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/DynamicUser.cs @@ -0,0 +1,18 @@ +using System.Collections; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class DynamicUser : IEnumerable + { + public virtual int Id { get; set; } + + public virtual dynamic Properties { get; set; } + + public virtual IDictionary Settings { get; set; } + + public virtual IEnumerator GetEnumerator() + { + throw new System.NotImplementedException(); + } + } +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index c4cbda23f26..4551ce0e9d8 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -69,6 +69,11 @@ public IQueryable Users get { return _session.Query(); } } + public IQueryable DynamicUsers + { + get { return _session.Query(); } + } + public IQueryable PatientRecords { get { return _session.Query(); } diff --git a/src/NHibernate.DomainModel/Northwind/Entities/User.cs b/src/NHibernate.DomainModel/Northwind/Entities/User.cs index c23e667be9b..14096dac912 100644 --- a/src/NHibernate.DomainModel/Northwind/Entities/User.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/User.cs @@ -48,10 +48,16 @@ public class User : IUser, IEntity public virtual FeatureSet Features { get; set; } + public virtual User NotMappedUser => this; + public virtual EnumStoredAsString Enum1 { get; set; } + public virtual EnumStoredAsString? NullableEnum1 { get; set; } + public virtual EnumStoredAsInt32 Enum2 { get; set; } + public virtual EnumStoredAsInt32? NullableEnum2 { get; set; } + public virtual IUser CreatedBy { get; set; } public virtual IUser ModifiedBy { get; set; } diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml new file mode 100644 index 00000000000..1b6775b29c6 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Mappings/DynamicUser.hbm.xml @@ -0,0 +1,30 @@ + + + + + select * from Users + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml index 2764cb70898..f249de9574e 100644 --- a/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml +++ b/src/NHibernate.DomainModel/Northwind/Mappings/User.hbm.xml @@ -24,8 +24,14 @@ + + + + + diff --git a/src/NHibernate.Test/Async/Linq/EnumTests.cs b/src/NHibernate.Test/Async/Linq/EnumTests.cs index 622a806ed30..6e9355d294c 100644 --- a/src/NHibernate.Test/Async/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Async/Linq/EnumTests.cs @@ -61,5 +61,42 @@ public async Task CanQueryOnEnumStoredAsString_Small_1Async() Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public async Task ConditionalNavigationPropertyAsync() + { + EnumStoredAsString? type = null; + await (db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToListAsync()); + await (db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToListAsync()); + await (db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToListAsync()); + await (db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToListAsync()); + + await (db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToListAsync()); + } + + [Test] + public async Task CanQueryComplexExpressionOnEnumStoredAsStringAsync() + { + var type = EnumStoredAsString.Unspecified; + var query = await ((from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToListAsync()); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index 4fbebe3e78b..0956fdfe92b 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -88,6 +88,34 @@ public async Task UsingTwoEntityParametersAsync() 2)); } + [Test] + public async Task UsingEntityEnumerableParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = await (db.DynamicUsers.FirstAsync()); + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1)); + } + + [Test] + public async Task UsingEntityEnumerableListParameterTwiceAsync() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {await (db.DynamicUsers.FirstAsync())}; + await (AssertTotalParametersAsync( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1)); + } + [Test] public async Task UsingValueTypeParameterTwiceAsync() { diff --git a/src/NHibernate.Test/Linq/ConstantTest.cs b/src/NHibernate.Test/Linq/ConstantTest.cs index 6b693ddbc4a..f75bf3a9070 100644 --- a/src/NHibernate.Test/Linq/ConstantTest.cs +++ b/src/NHibernate.Test/Linq/ConstantTest.cs @@ -217,12 +217,12 @@ public void ConstantInWhereDoesNotCauseManyKeys() select c); var preTransformParameters = new PreTransformationParameters(QueryMode.Select, Sfi); var preTransformResult = NhRelinqQueryParser.PreTransform(q1.Expression, preTransformParameters); - var expression = ExpressionParameterVisitor.Visit(preTransformResult, out var parameters1); - var k1 = ExpressionKeyVisitor.Visit(expression, parameters1); + var parameters1 = ExpressionParameterVisitor.Visit(preTransformResult); + var k1 = ExpressionKeyVisitor.Visit(preTransformResult.Expression, parameters1, Sfi); var preTransformResult2 = NhRelinqQueryParser.PreTransform(q2.Expression, preTransformParameters); - var expression2 = ExpressionParameterVisitor.Visit(preTransformResult2, out var parameters2); - var k2 = ExpressionKeyVisitor.Visit(expression2, parameters2); + var parameters2 = ExpressionParameterVisitor.Visit(preTransformResult2); + var k2 = ExpressionKeyVisitor.Visit(preTransformResult2.Expression, parameters2, Sfi); Assert.That(parameters1, Has.Count.GreaterThan(0), "parameters1"); Assert.That(parameters2, Has.Count.GreaterThan(0), "parameters2"); diff --git a/src/NHibernate.Test/Linq/EnumTests.cs b/src/NHibernate.Test/Linq/EnumTests.cs index 4050c7ddb97..aeea060b51e 100644 --- a/src/NHibernate.Test/Linq/EnumTests.cs +++ b/src/NHibernate.Test/Linq/EnumTests.cs @@ -48,5 +48,42 @@ public void CanQueryOnEnumStoredAsString(EnumStoredAsString type, int expectedCo Assert.AreEqual(expectedCount, query.Count); } + + [Test] + public void ConditionalNavigationProperty() + { + EnumStoredAsString? type = null; + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large).ToList(); + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1).ToList(); + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => ((o.NullableEnum1 ?? type) ?? o.Enum1) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium).ToList(); + db.Users.Where(o => (o.Enum1 != EnumStoredAsString.Large + ? (o.NullableEnum1.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium).ToList(); + + db.Users.Where(o => (o.Enum1 == EnumStoredAsString.Large ? o.Role : o.Role).Name == "test").ToList(); + } + + [Test] + public void CanQueryComplexExpressionOnEnumStoredAsString() + { + var type = EnumStoredAsString.Unspecified; + var query = (from user in db.Users + where (user.NullableEnum1 == EnumStoredAsString.Large + ? EnumStoredAsString.Medium + : user.NullableEnum1 ?? user.Enum1 + ) == type + select new + { + user, + simple = user.Enum1, + condition = user.Enum1 == EnumStoredAsString.Large ? EnumStoredAsString.Medium : user.Enum1, + coalesce = user.NullableEnum1 ?? EnumStoredAsString.Medium + }).ToList(); + + Assert.That(query.Count, Is.EqualTo(0)); + } } } diff --git a/src/NHibernate.Test/Linq/LinqTestCase.cs b/src/NHibernate.Test/Linq/LinqTestCase.cs index e047732d7ad..daf14b9cd18 100755 --- a/src/NHibernate.Test/Linq/LinqTestCase.cs +++ b/src/NHibernate.Test/Linq/LinqTestCase.cs @@ -34,7 +34,8 @@ protected override string[] Mappings "Northwind.Mappings.User.hbm.xml", "Northwind.Mappings.TimeSheet.hbm.xml", "Northwind.Mappings.Animal.hbm.xml", - "Northwind.Mappings.Patient.hbm.xml" + "Northwind.Mappings.Patient.hbm.xml", + "Northwind.Mappings.DynamicUser.hbm.xml" }; } } diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index 920fa565129..cab27fe9dd5 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -76,6 +76,34 @@ public void UsingTwoEntityParameters() 2); } + [Test] + public void UsingEntityEnumerableParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = db.DynamicUsers.First(); + AssertTotalParameters( + db.DynamicUsers.Where(o => o == enumerable && o != enumerable), + 1); + } + + [Test] + public void UsingEntityEnumerableListParameterTwice() + { + if (!Dialect.SupportsSubSelects) + { + Assert.Ignore(); + } + + var enumerable = new[] {db.DynamicUsers.First()}; + AssertTotalParameters( + db.DynamicUsers.Where(o => enumerable.Contains(o) && enumerable.Contains(o)), + 1); + } + [Test] public void UsingValueTypeParameterTwice() { diff --git a/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs new file mode 100644 index 00000000000..2cb87bc50b2 --- /dev/null +++ b/src/NHibernate.Test/Linq/ParameterTypeLocatorTests.cs @@ -0,0 +1,444 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Linq.Dynamic.Core; +using System.Linq.Expressions; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Engine.Query; +using NHibernate.Linq; +using NHibernate.Linq.Visitors; +using NHibernate.Type; +using NUnit.Framework; +using Remotion.Linq.Clauses; + +namespace NHibernate.Test.Linq +{ + public class ParameterTypeLocatorTests : LinqTestCase + { + [Test] + public void AddIntegerTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5", o => o is Int32Type}, + }, + db.Users.Where(o => o.Id + 5 > 2.1), + db.Users.Where(o => 2.1 < 5 + o.Id) + ); + } + + [Test] + public void AddDecimalTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DecimalType}, + {"5.2", o => o is DecimalType}, + }, + db.Users.Where(o => o.Id + 5.2m > 2.1m), + db.Users.Where(o => 2.1m < 5.2m + o.Id) + ); + } + + [Test] + public void SubtractFloatTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is DoubleType}, + {"5.2", o => o is SingleType}, + }, + db.Users.Where(o => o.Id - 5.2f > 2.1), + db.Users.Where(o => 2.1 < 5.2f - o.Id) + ); + } + + [Test] + public void GreaterThanTest() + { + AssertResults( + new Dictionary> + { + {"2.1", o => o is Int32Type} + }, + db.Users.Where(o => o.Id > 2.1), + db.Users.Where(o => 2.1 > o.Id) + ); + } + + [Test] + public void EqualStringEnumTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void EqualStringTest() + { + AssertResults( + new Dictionary> + { + {"\"London\"", o => o is StringType stringType && stringType.SqlType.Length == 15} + }, + db.Orders.Where(o => o.ShippingAddress.City == "London"), + db.Orders.Where(o => "London" == o.ShippingAddress.City) + ); + } + + [Test] + public void EqualEntityTest() + { + var order = new Order(); + AssertResults( + new Dictionary> + { + { + $"value({typeof(Order).FullName})", + o => o is ManyToOneType manyToOne && manyToOne.Name == typeof(Order).FullName + } + }, + db.Orders.Where(o => o == order), + db.Orders.Where(o => order == o) + ); + } + + [Test] + public void DoubleEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 == EnumStoredAsString.Large && o.Enum2 == EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High == o.Enum2 && EnumStoredAsString.Large == o.Enum1) + ); + } + + [Test] + public void NotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large), + db.Users.Where(o => EnumStoredAsString.Large != o.Enum1) + ); + } + + [Test] + public void DoubleNotEqualTest() + { + AssertResults( + new Dictionary> + { + {"3", o => o is EnumStoredAsStringType}, + {"1", o => o is PersistentEnumType} + }, + db.Users.Where(o => o.Enum1 != EnumStoredAsString.Large || o.NullableEnum2 != EnumStoredAsInt32.High), + db.Users.Where(o => EnumStoredAsInt32.High != o.NullableEnum2 || o.Enum1 != EnumStoredAsString.Large) + ); + } + + [Test] + public void CoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum1 ?? EnumStoredAsString.Large) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum1 ?? EnumStoredAsString.Large)) + ); + } + + [Test] + public void DoubleCoalesceTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + db.Users.Where(o => ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == ((o.NullableEnum1 ?? (EnumStoredAsString?) EnumStoredAsString.Large) ?? o.Enum1)) + ); + } + + [Test] + public void ConditionalTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1)) + ); + } + + [Test] + public void DoubleConditionalTest() + { + AssertResults( + new Dictionary> + { + {"0", o => o is PersistentEnumType}, + {"2", o => o is EnumStoredAsStringType}, + {"Small", o => o is EnumStoredAsStringType}, + {"Unspecified", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.Enum2 != EnumStoredAsInt32.Unspecified + ? (o.NullableEnum2.HasValue ? o.Enum1 : EnumStoredAsString.Unspecified) + : EnumStoredAsString.Small) == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Enum2 != EnumStoredAsInt32.Unspecified + ? EnumStoredAsString.Small + : (o.NullableEnum2.HasValue ? EnumStoredAsString.Unspecified : o.Enum1))) + ); + } + + [Test] + public void CoalesceMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType} + }, + db.Users.Where(o => (o.NotMappedUser ?? o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o ?? o.NotMappedUser).Enum1) + ); + } + + [Test] + public void ConditionalMemberTest() + { + AssertResults( + new Dictionary> + { + {"2", o => o is EnumStoredAsStringType}, + {"\"test\"", o => o is AnsiStringType}, + }, + db.Users.Where(o => (o.Name == "test" ? o.NotMappedUser : o).Enum1 == EnumStoredAsString.Medium), + db.Users.Where(o => EnumStoredAsString.Medium == (o.Name == "test" ? o : o.NotMappedUser).Enum1) + ); + } + + [Test] + public void DynamicMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, + db.DynamicUsers.Where("Properties.Name == @0", "test"), + db.DynamicUsers.Where("@0 == Properties.Name", "test") + ); + } + + [Test] + public void DynamicDictionaryMemberTest() + { + AssertResults( + new Dictionary> + { + {"\"test\"", o => o is AnsiStringType}, + }, +#pragma warning disable CS0252 + db.DynamicUsers.Where(o => o.Settings["Property1"] == "test"), +#pragma warning restore CS0252 +#pragma warning disable CS0253 + db.DynamicUsers.Where(o => "test" == o.Settings["Property1"]) +#pragma warning restore CS0253 + ); + } + + [Test] + public void AssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User {Component = new UserComponent {Property1 = "prop1"}} + ); + } + + [Test] + public void AssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new User + { + Component = new UserComponent {OtherComponent = new UserComponent2 {OtherProperty1 = "other"}} + } + ); + } + + [Test] + public void AnonymousAssignMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"val\"", o => o is AnsiStringType}, + {"Large", o => o is EnumStoredAsStringType}, + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Name = "val", Enum1 = EnumStoredAsString.Large} + ); + } + + [Test] + public void AnonymousAssignComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"prop1\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {Property1 = "prop1"}} + ); + } + + [Test] + public void AnonymousAssignNestedComponentMemberTest() + { + AssertResult( + new Dictionary> + { + {"0", o => o is Int32Type}, + {"\"other\"", o => o is AnsiStringType} + }, + QueryMode.Insert, + db.Users.Where(o => o.InvalidLoginAttempts > 0), + o => new {Component = new {OtherComponent = new {OtherProperty1 = "other"}}} + ); + } + + private void AssertResults( + Dictionary> expectedResults, + params IQueryable[] queries) + { + foreach (var query in queries) + { + AssertResult(expectedResults, query); + } + } + + private void AssertResult( + Dictionary> expectedResults, + IQueryable query) + { + AssertResult(expectedResults, QueryMode.Select, query.Expression, query.Expression.Type); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpression(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + IQueryable query, + Expression> expression) + { + var dmlExpression = expression != null + ? DmlExpressionRewriter.PrepareExpressionFromAnonymous(query.Expression, expression) + : query.Expression; + + AssertResult(expectedResults, queryMode, dmlExpression, typeof(T)); + } + + private void AssertResult( + Dictionary> expectedResults, + QueryMode queryMode, + Expression expression, + System.Type targetType) + { + var result = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(queryMode, Sfi)); + var parameters = ExpressionParameterVisitor.Visit(result); + expression = result.Expression; + var queryModel = NhRelinqQueryParser.Parse(expression); + ParameterTypeLocator.SetParameterTypes(parameters, queryModel, targetType, Sfi); + Assert.That(parameters.Count, Is.EqualTo(expectedResults.Count), "Incorrect number of parameters"); + foreach (var pair in parameters) + { + var origCulture = CultureInfo.CurrentCulture; + try + { + CultureInfo.CurrentCulture = CultureInfo.InvariantCulture; + var expressionText = pair.Key.ToString(); + Assert.That(expectedResults.ContainsKey(expressionText), Is.True, $"{expressionText} constant is not expected"); + Assert.That(expectedResults[expressionText](pair.Value.Type), Is.True, $"Invalid type, actual type: {pair.Value?.Type?.Name ?? "null"}"); + } + finally + { + CultureInfo.CurrentCulture = origCulture; + } + } + } + } +} diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index 20610d32bad..880ce2e3e69 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -774,7 +774,8 @@ private void AssertResult( var expression = query.Expression; var preTransformResult = NhRelinqQueryParser.PreTransform(expression, new PreTransformationParameters(QueryMode.Select, Sfi)); - expression = ExpressionParameterVisitor.Visit(preTransformResult, out var constantToParameterMap); + expression = preTransformResult.Expression; + var constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); var queryModel = NhRelinqQueryParser.Parse(expression); var requiredHqlParameters = new List(); var visitorParameters = new VisitorParameters( diff --git a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs index 5318d771ec2..ee2b82c02f9 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH1526/Fixture.cs @@ -71,7 +71,7 @@ public void ShouldCreateDifferentKeys_TypeBinaryExpression() private static string GetCacheKey(Expression exp) { - return ExpressionKeyVisitor.Visit(exp, new Dictionary()); + return ExpressionKeyVisitor.Visit(exp, new Dictionary(), null); } } } diff --git a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs index 4d0344a27eb..2cc4b8863b4 100644 --- a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs @@ -21,6 +21,7 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { @@ -103,7 +104,7 @@ public Task ExecuteDmlAsync(QueryMode queryMode, Expression expression, var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdateAsync(cancellationToken); } diff --git a/src/NHibernate/Driver/OdbcDriver.cs b/src/NHibernate/Driver/OdbcDriver.cs index cf8df041cea..5ac80336849 100644 --- a/src/NHibernate/Driver/OdbcDriver.cs +++ b/src/NHibernate/Driver/OdbcDriver.cs @@ -78,10 +78,18 @@ private void SetVariableLengthParameterSize(DbParameter dbParam, SqlType sqlType { switch (dbParam.DbType) { - case DbType.AnsiString: + case DbType.StringFixedLength: case DbType.AnsiStringFixedLength: + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } + + break; case DbType.String: - case DbType.StringFixedLength: + case DbType.AnsiString: // NH-4083: do not limit to column length if above 2000. Setting size may trigger conversion from // nvarchar to ntext when size is superior or equal to 2000, causing some queries to fail: // https://stackoverflow.com/q/8569844/1178314 diff --git a/src/NHibernate/Driver/SqlClientDriver.cs b/src/NHibernate/Driver/SqlClientDriver.cs index 682d755472a..1002005395d 100644 --- a/src/NHibernate/Driver/SqlClientDriver.cs +++ b/src/NHibernate/Driver/SqlClientDriver.cs @@ -161,7 +161,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq { case DbType.AnsiString: case DbType.AnsiStringFixedLength: - dbParam.Size = IsAnsiText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForAnsiClob : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; + dbParam.Size = IsAnsiText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForAnsiClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedAnsiString; break; case DbType.Binary: dbParam.Size = IsBlob(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForBlob : MsSql2000Dialect.MaxSizeForLengthLimitedBinary; @@ -174,7 +176,9 @@ protected override void InitializeParameter(DbParameter dbParam, string name, Sq break; case DbType.String: case DbType.StringFixedLength: - dbParam.Size = IsText(dbParam, sqlType) ? MsSql2000Dialect.MaxSizeForClob : MsSql2000Dialect.MaxSizeForLengthLimitedString; + dbParam.Size = IsText(dbParam, sqlType) + ? MsSql2000Dialect.MaxSizeForClob + : IsChar(dbParam, sqlType) ? sqlType.Length : MsSql2000Dialect.MaxSizeForLengthLimitedString; break; case DbType.DateTime2: dbParam.Size = MsSql2000Dialect.MaxDateTime2; @@ -283,6 +287,18 @@ protected static bool IsBlob(DbParameter dbParam, SqlType sqlType) return (sqlType is BinaryBlobSqlType) || ((DbType.Binary == dbParam.DbType) && sqlType.LengthDefined && (sqlType.Length > MsSql2000Dialect.MaxSizeForLengthLimitedBinary)); } + /// + /// Interprets if a parameter is a character (for the purposes of setting its default size) + /// + /// The parameter + /// The of the parameter + /// True, if the parameter should be interpreted as a character, otherwise False + protected static bool IsChar(DbParameter dbParam, SqlType sqlType) + { + return (DbType.StringFixedLength == dbParam.DbType || DbType.AnsiStringFixedLength == dbParam.DbType) && + sqlType.LengthDefined && sqlType.Length == 1; + } + public override IResultSetsCommand GetResultSetsCommand(ISessionImplementor session) { return new BasicResultSetsCommand(session); diff --git a/src/NHibernate/Driver/SqlServerCeDriver.cs b/src/NHibernate/Driver/SqlServerCeDriver.cs index eb4f03316ea..0b6a4ad93bc 100644 --- a/src/NHibernate/Driver/SqlServerCeDriver.cs +++ b/src/NHibernate/Driver/SqlServerCeDriver.cs @@ -75,6 +75,12 @@ public override IResultSetsCommand GetResultSetsCommand(Engine.ISessionImplement protected override void InitializeParameter(DbParameter dbParam, string name, SqlType sqlType) { base.InitializeParameter(dbParam, name, AdjustSqlType(sqlType)); + // For types that are using one character (CharType, AnsiCharType, TrueFalseType, YesNoType and EnumCharType), + // we have to specify the length otherwise sql function like charindex won't work as expected. + if (sqlType.LengthDefined && sqlType.Length == 1) + { + dbParam.Size = sqlType.Length; + } AdjustDbParamTypeForLargeObjects(dbParam, sqlType); } diff --git a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs index 7b9e937bd18..e7b95eed2cb 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/ASTQueryTranslatorFactory.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using NHibernate.Engine; using NHibernate.Hql.Ast.ANTLR.Tree; +using NHibernate.Linq; using NHibernate.Util; namespace NHibernate.Hql.Ast.ANTLR @@ -16,15 +17,24 @@ public class ASTQueryTranslatorFactory : IQueryTranslatorFactory { public IQueryTranslator[] CreateQueryTranslators(IQueryExpression queryExpression, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) { - return CreateQueryTranslators(queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); + return CreateQueryTranslators(queryExpression, queryExpression.Translate(factory, collectionRole != null), queryExpression.Key, collectionRole, shallow, filters, factory); } - static IQueryTranslator[] CreateQueryTranslators(IASTNode ast, string queryIdentifier, string collectionRole, bool shallow, IDictionary filters, ISessionFactoryImplementor factory) + static IQueryTranslator[] CreateQueryTranslators( + IQueryExpression queryExpression, + IASTNode ast, + string queryIdentifier, + string collectionRole, + bool shallow, + IDictionary filters, + ISessionFactoryImplementor factory) { var polymorphicParsers = AstPolymorphicProcessor.Process(ast, factory); var translators = polymorphicParsers - .ToArray(hql => new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); + .ToArray(hql => queryExpression is NhLinqExpression linqExpression + ? new QueryTranslatorImpl(queryIdentifier, hql, filters, factory, linqExpression.NamedParameters) + : new QueryTranslatorImpl(queryIdentifier, hql, filters, factory)); foreach (var translator in translators) { diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs index 37e5eaffc6e..d6f3a7f0861 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.cs @@ -36,7 +36,7 @@ public partial class HqlSqlWalker private string _statementTypeName; private int _positionalParameterCount; private int _parameterCount; - private readonly NullableDictionary _namedParameters = new NullableDictionary(); + private readonly NullableDictionary _namedParameterLocations = new NullableDictionary(); private readonly List _parameters = new List(); private FromClause _currentFromClause; private SelectClause _selectClause; @@ -54,6 +54,7 @@ public partial class HqlSqlWalker private readonly LiteralProcessor _literalProcessor; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private JoinType _impliedJoinType; @@ -64,17 +65,30 @@ public partial class HqlSqlWalker private int numberOfParametersInSetClause; private Stack clauseStack=new Stack(); - public HqlSqlWalker(QueryTranslatorImpl qti, - ISessionFactoryImplementor sfi, - ITreeNodeStream input, - IDictionary tokenReplacements, - string collectionRole) + public HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + string collectionRole) + : this(qti, sfi, input, tokenReplacements, null, collectionRole) + { + } + + internal HqlSqlWalker( + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + ITreeNodeStream input, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) : this(input) { _sessionFactoryHelper = new SessionFactoryHelperExtensions(sfi); _qti = qti; _literalProcessor = new LiteralProcessor(this); _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionFilterRole = collectionRole; } @@ -122,7 +136,7 @@ public ISet QuerySpaces public IDictionary NamedParameters { - get { return _namedParameters; } + get { return _namedParameterLocations; } } internal SessionFactoryHelperExtensions SessionFactoryHelper @@ -1033,13 +1047,20 @@ IASTNode GenerateNamedParameter(IASTNode delimiterNode, IASTNode nameNode) ); parameter.HqlParameterSpecification = paramSpec; + if (_namedParameters != null && _namedParameters.TryGetValue(name, out var namedParameter)) + { + // Add the parameter type information so that we are able to calculate functions return types + // when the parameter is used as an argument. + parameter.ExpectedType = namedParameter.Type; + } + _parameters.Add(paramSpec); return parameter; } IASTNode GeneratePositionalParameter(IASTNode inputNode) { - if (_namedParameters.Count > 0) + if (_namedParameterLocations.Count > 0) { // NH TODO: remove this limitation throw new SemanticException("cannot define positional parameter after any named parameters have been defined"); @@ -1171,15 +1192,15 @@ public void AddQuerySpaces(string[] spaces) private void TrackNamedParameterPositions(string name) { int loc = _parameterCount++; - object o = _namedParameters[name]; + object o = _namedParameterLocations[name]; if ( o == null ) { - _namedParameters.Add(name, loc); + _namedParameterLocations.Add(name, loc); } else if (o is int) { List list = new List(4) {(int) o, loc}; - _namedParameters[name] = list; + _namedParameterLocations[name] = list; } else { diff --git a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs index 2e27559d1dd..bcf3dc14e11 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/QueryTranslatorImpl.cs @@ -29,7 +29,8 @@ public partial class QueryTranslatorImpl : IFilterTranslator private readonly string _queryIdentifier; private readonly IASTNode _stageOneAst; private readonly ISessionFactoryImplementor _factory; - + private readonly IDictionary _namedParameters; + private bool _shallowQuery; private bool _compiled; private IDictionary _enabledFilters; @@ -47,10 +48,28 @@ public partial class QueryTranslatorImpl : IFilterTranslator /// Currently enabled filters /// The session factory constructing this translator instance. public QueryTranslatorImpl( - string queryIdentifier, - IASTNode parsedQuery, - IDictionary enabledFilters, - ISessionFactoryImplementor factory) + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory) + : this(queryIdentifier, parsedQuery, enabledFilters, factory, null) + { + } + + /// + /// Creates a new AST-based query translator. + /// + /// The query-identifier (used in stats collection) + /// The hql query to translate + /// Currently enabled filters + /// The session factory constructing this translator instance. + /// The named parameters information. + internal QueryTranslatorImpl( + string queryIdentifier, + IASTNode parsedQuery, + IDictionary enabledFilters, + ISessionFactoryImplementor factory, + IDictionary namedParameters) { _queryIdentifier = queryIdentifier; _stageOneAst = parsedQuery; @@ -58,6 +77,7 @@ public QueryTranslatorImpl( _shallowQuery = false; _enabledFilters = enabledFilters; _factory = factory; + _namedParameters = namedParameters; } /// @@ -434,7 +454,7 @@ private static IStatementExecutor BuildAppropriateStatementExecutor(IStatement s private HqlSqlTranslator Analyze(string collectionRole) { - var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, collectionRole); + var translator = new HqlSqlTranslator(_stageOneAst, this, _factory, _tokenReplacements, _namedParameters, collectionRole); translator.Translate(); @@ -548,15 +568,23 @@ internal class HqlSqlTranslator private readonly QueryTranslatorImpl _qti; private readonly ISessionFactoryImplementor _sfi; private readonly IDictionary _tokenReplacements; + private readonly IDictionary _namedParameters; private readonly string _collectionRole; private IStatement _resultAst; - public HqlSqlTranslator(IASTNode ast, QueryTranslatorImpl qti, ISessionFactoryImplementor sfi, IDictionary tokenReplacements, string collectionRole) + public HqlSqlTranslator( + IASTNode ast, + QueryTranslatorImpl qti, + ISessionFactoryImplementor sfi, + IDictionary tokenReplacements, + IDictionary namedParameters, + string collectionRole) { _inputAst = ast; _qti = qti; _sfi = sfi; _tokenReplacements = tokenReplacements; + _namedParameters = namedParameters; _collectionRole = collectionRole; } @@ -576,7 +604,7 @@ public IStatement Translate() var nodes = new BufferedTreeNodeStream(_inputAst); - var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _collectionRole); + var hqlSqlWalker = new HqlSqlWalker(_qti, _sfi, nodes, _tokenReplacements, _namedParameters, _collectionRole); hqlSqlWalker.TreeAdaptor = new HqlSqlWalkerTreeAdaptor(hqlSqlWalker); try diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs index fd91e09fd3a..2ff3a10f790 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BetweenOperatorNode.cs @@ -63,25 +63,13 @@ private IASTNode GetHighOperand() private static void Check(IASTNode check, IASTNode first, IASTNode second) { - var expectedTypeAwareNode = check as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (!(check is IExpectedTypeAwareNode expectedTypeAwareNode) || + expectedTypeAwareNode.ExpectedType != null) { - IType expectedType = null; - var firstNode = first as SqlNode; - if (firstNode != null) - { - expectedType = firstNode.DataType; - } - if (expectedType == null) - { - var secondNode = second as SqlNode; - if (secondNode != null) - { - expectedType = secondNode.DataType; - } - } - expectedTypeAwareNode.ExpectedType = expectedType; + return; } + + expectedTypeAwareNode.ExpectedType = (first as SqlNode)?.DataType ?? (second as SqlNode)?.DataType; } } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs index 9b706facb49..60ca24ca379 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryArithmeticOperatorNode.cs @@ -32,32 +32,34 @@ public void Initialize() IType lhType = (lhs is SqlNode) ? ((SqlNode)lhs).DataType : null; IType rhType = (rhs is SqlNode) ? ((SqlNode)rhs).DataType : null; - if (lhs is IExpectedTypeAwareNode && rhType != null) + TrySetExpectedType(lhs, rhType, true); + TrySetExpectedType(rhs, lhType, false); + } + + private void TrySetExpectedType(IASTNode operand, IType otherOperandType, bool leftHandOperand) + { + if (!(operand is IExpectedTypeAwareNode typeAwareNode) || + otherOperandType == null || + typeAwareNode.ExpectedType != null) { - IType expectedType; + return; + } + + IType expectedType = null; - // we have something like : "? [op] rhs" - if (IsDateTimeType(rhType)) + // we have something like : "lhs [op] ?" or "? [op] rhs" + if (IsDateTimeType(otherOperandType)) + { + if (leftHandOperand) { // more specifically : "? [op] datetime" // 1) if the operator is MINUS, the param needs to be of // some datetime type // 2) if the operator is PLUS, the param needs to be of // some numeric type - expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : rhType; + expectedType = Type == HqlSqlWalker.PLUS ? NHibernateUtil.Double : otherOperandType; } - else - { - expectedType = rhType; - } - ((IExpectedTypeAwareNode)lhs).ExpectedType = expectedType; - } - else if (rhs is ParameterNode && lhType != null) - { - IType expectedType = null; - - // we have something like : "lhs [op] ?" - if (IsDateTimeType(lhType)) + else if (Type == HqlSqlWalker.PLUS) { // more specifically : "datetime [op] ?" // 1) if the operator is MINUS, we really cannot determine @@ -65,17 +67,15 @@ public void Initialize() // numeric would be valid // 2) if the operator is PLUS, the param needs to be of // some numeric type - if (Type == HqlSqlWalker.PLUS) - { - expectedType = NHibernateUtil.Double; - } - } - else - { - expectedType = lhType; + expectedType = NHibernateUtil.Double; } - ((IExpectedTypeAwareNode)rhs).ExpectedType = expectedType; } + else + { + expectedType = otherOperandType; + } + + typeAwareNode.ExpectedType = expectedType; } public override IType DataType diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs index bf0560dfc76..cae4b920ec8 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/BinaryLogicOperatorNode.cs @@ -65,15 +65,14 @@ public virtual void Initialize() rhsType = lhsType; } - var lshExpectedTypeAwareNode = lhs as IExpectedTypeAwareNode; - if (lshExpectedTypeAwareNode != null) + if (lhs is IExpectedTypeAwareNode lshTypeAwareNode && lshTypeAwareNode.ExpectedType == null) { - lshExpectedTypeAwareNode.ExpectedType = rhsType; + lshTypeAwareNode.ExpectedType = rhsType; } - var rshExpectedTypeAwareNode = rhs as IExpectedTypeAwareNode; - if (rshExpectedTypeAwareNode != null) + + if (rhs is IExpectedTypeAwareNode rshTypeAwareNode && rshTypeAwareNode.ExpectedType == null) { - rshExpectedTypeAwareNode.ExpectedType = lhsType; + rshTypeAwareNode.ExpectedType = lhsType; } MutateRowValueConstructorSyntaxesIfNecessary( lhsType, rhsType ); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs index 0ad4e404bda..f0b9856f76a 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/InLogicOperatorNode.cs @@ -47,11 +47,12 @@ public override void Initialize() IASTNode inListChild = inList.GetChild(0); while (inListChild != null) { - var expectedTypeAwareNode = inListChild as IExpectedTypeAwareNode; - if (expectedTypeAwareNode != null) + if (inListChild is IExpectedTypeAwareNode expectedTypeAwareNode && + expectedTypeAwareNode.ExpectedType == null) { expectedTypeAwareNode.ExpectedType = lhsType; } + inListChild = inListChild.NextSibling; } } diff --git a/src/NHibernate/Impl/AbstractQueryImpl.cs b/src/NHibernate/Impl/AbstractQueryImpl.cs index 9ff4c712b0d..ba46b665466 100644 --- a/src/NHibernate/Impl/AbstractQueryImpl.cs +++ b/src/NHibernate/Impl/AbstractQueryImpl.cs @@ -142,7 +142,8 @@ protected internal virtual IType DetermineType(int paramPosition, object paramVa protected internal virtual IType DetermineType(int paramPosition, object paramValue) { - IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? GuessType(paramValue); + IType type = parameterMetadata.GetOrdinalParameterExpectedType(paramPosition + 1) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } @@ -154,67 +155,15 @@ protected internal virtual IType DetermineType(string paramName, object paramVal protected internal virtual IType DetermineType(string paramName, object paramValue) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(paramValue); + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(paramValue, session.Factory); return type; } protected internal virtual IType DetermineType(string paramName, System.Type clazz) { - IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? GuessType(clazz); - return type; - } - - /// - /// Guesses the from the param's value. - /// - /// The object to guess the of. - /// An for the object. - /// - /// Thrown when the param is null because the - /// can't be guess from a null value. - /// - private IType GuessType(object param) - { - if (param == null) - { - throw new ArgumentNullException("param", "The IType can not be guessed for a null value."); - } - - System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); - return GuessType(clazz); - } - - /// - /// Guesses the from the . - /// - /// The to guess the of. - /// An for the . - /// - /// Thrown when the clazz is null because the - /// can't be guess from a null type. - /// - private IType GuessType(System.Type clazz) - { - if (clazz == null) - { - throw new ArgumentNullException("clazz", "The IType can not be guessed for a null value."); - } - - var type = TypeFactory.HeuristicType(clazz); - if (type == null || type is SerializableType) - { - if (session.Factory.TryGetEntityPersister(clazz.FullName) != null) - { - return NHibernateUtil.Entity(clazz); - } - - if (type == null) - { - throw new HibernateException( - "Could not determine a type for class: " + clazz.AssemblyQualifiedName); - } - } - + IType type = parameterMetadata.GetNamedParameterExpectedType(paramName) ?? + ParameterHelper.GuessType(clazz, session.Factory); return type; } @@ -310,7 +259,11 @@ public IQuery SetParameter(int position, T val) { CheckPositionalParameter(position); - return SetParameter(position, val, parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? GuessType(typeof(T))); + return SetParameter( + position, + val, + parameterMetadata.GetOrdinalParameterExpectedType(position + 1) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } private void CheckPositionalParameter(int position) @@ -327,7 +280,11 @@ private void CheckPositionalParameter(int position) public IQuery SetParameter(string name, T val) { - return SetParameter(name, val, parameterMetadata.GetNamedParameterExpectedType(name) ?? GuessType(typeof (T))); + return SetParameter( + name, + val, + parameterMetadata.GetNamedParameterExpectedType(name) ?? + ParameterHelper.GuessType(typeof(T), session.Factory)); } public IQuery SetParameter(string name, object val) @@ -792,7 +749,12 @@ public IQuery SetParameterList(string name, IEnumerable vals) } object firstValue = vals.Cast().FirstOrDefault(); - SetParameterList(name, vals, firstValue == null ? GuessType(vals.GetCollectionElementType()) : DetermineType(name, firstValue)); + SetParameterList( + name, + vals, + firstValue == null + ? ParameterHelper.GuessType(vals.GetCollectionElementType(), session.Factory) + : DetermineType(name, firstValue)); return this; } diff --git a/src/NHibernate/Linq/DefaultQueryProvider.cs b/src/NHibernate/Linq/DefaultQueryProvider.cs index c8de5a37a5e..912b640a951 100644 --- a/src/NHibernate/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Linq/DefaultQueryProvider.cs @@ -11,6 +11,7 @@ using NHibernate.Util; using System.Threading.Tasks; using NHibernate.Multi; +using NHibernate.Param; namespace NHibernate.Linq { @@ -211,7 +212,7 @@ protected virtual NhLinqExpression PrepareQuery(Expression expression, out IQuer query = Session.CreateFilter(Collection, nhLinqExpression); } - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); SetResultTransformerAndAdditionalCriteria(query, nhLinqExpression, nhLinqExpression.ParameterValuesByName); @@ -252,38 +253,19 @@ protected virtual object ExecuteQuery(NhLinqExpression nhLinqExpression, IQuery #pragma warning restore 618 } - private static void SetParameters(IQuery query, IDictionary> parameters) + private static void SetParameters(IQuery query, IDictionary parameters) { foreach (var parameterName in query.NamedParameters) { - var param = parameters[parameterName]; - - if (param.Item1 == null) + // The parameter type will be taken from the parameter metadata + var parameter = parameters[parameterName]; + if (parameter.IsCollection) { - if (typeof(IEnumerable).IsAssignableFrom(param.Item2.ReturnedClass) && - param.Item2.ReturnedClass != typeof(string)) - { - query.SetParameterList(parameterName, null, param.Item2); - } - else - { - query.SetParameter(parameterName, null, param.Item2); - } + query.SetParameterList(parameter.Name, (IEnumerable) parameter.Value); } else { - if (param.Item1 is IEnumerable && !(param.Item1 is string)) - { - query.SetParameterList(parameterName, (IEnumerable)param.Item1); - } - else if (param.Item2 != null) - { - query.SetParameter(parameterName, param.Item1, param.Item2); - } - else - { - query.SetParameter(parameterName, param.Item1); - } + query.SetParameter(parameter.Name, parameter.Value); } } } @@ -310,7 +292,7 @@ public int ExecuteDml(QueryMode queryMode, Expression expression) var query = Session.CreateQuery(nhLinqExpression); - SetParameters(query, nhLinqExpression.ParameterValuesByName); + SetParameters(query, nhLinqExpression.NamedParameters); _options?.Apply(query); return query.ExecuteUpdate(); } diff --git a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs index bf97cc54cb7..aa0f1ab88f8 100644 --- a/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/BaseHqlGeneratorForMethod.cs @@ -19,6 +19,13 @@ public abstract class BaseHqlGeneratorForMethod : IHqlGeneratorForMethod, IHqlGe public virtual bool AllowsNullableReturnType(MethodInfo method) => true; + /// + public virtual bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } + private protected static void LogIgnoredParameter(MethodInfo method, string paramType) { if (Log.IsWarnEnabled()) diff --git a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs index 33cb12c2c6c..cc0fa7202b9 100644 --- a/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs +++ b/src/NHibernate/Linq/Functions/GetValueOrDefaultGenerator.cs @@ -42,5 +42,12 @@ private static HqlExpression GetRhs(MethodInfo method, ReadOnlyCollection !method.ReturnType.IsValueType; + + /// + public bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } } diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index 73ad8b3d9e4..3ef7583ee6b 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -18,6 +18,14 @@ public interface IHqlGeneratorForMethod internal interface IHqlGeneratorForMethodExtended { bool AllowsNullableReturnType(MethodInfo method); + + /// + /// Try getting a collection parameter from . + /// + /// The method call expression. + /// Output parameter for the retrieved collection parameter. + /// Whether collection parameter was retrieved. + bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter); } internal static class HqlGeneratorForMethodExtensions @@ -33,6 +41,21 @@ public static bool AllowsNullableReturnType(this IHqlGeneratorForMethod generato return true; } + // 6.0 TODO: Remove + public static bool TryGetCollectionParameters( + this IHqlGeneratorForMethod generator, + MethodCallExpression expression, + out ConstantExpression collectionParameter) + { + if (generator is IHqlGeneratorForMethodExtended extendedGenerator) + { + return extendedGenerator.TryGetCollectionParameter(expression, out collectionParameter); + } + + collectionParameter = null; + return false; + } + // 6.0 TODO: merge into IHqlGeneratorForMethod /// /// Should pre-evaluation be allowed for this method? diff --git a/src/NHibernate/Linq/Functions/QueryableGenerator.cs b/src/NHibernate/Linq/Functions/QueryableGenerator.cs index f007fa22592..eaf44c64f4a 100644 --- a/src/NHibernate/Linq/Functions/QueryableGenerator.cs +++ b/src/NHibernate/Linq/Functions/QueryableGenerator.cs @@ -155,6 +155,15 @@ public CollectionContainsGenerator() public override bool AllowsNullableReturnType(MethodInfo method) => false; + /// + public override bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + var argument = expression.Method.IsStatic ? expression.Arguments[0] : expression.Object; + collectionParameter = argument as ConstantExpression; + + return collectionParameter != null; + } + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) { // TODO - alias generator diff --git a/src/NHibernate/Linq/Functions/StringGenerator.cs b/src/NHibernate/Linq/Functions/StringGenerator.cs index 5e358f22194..3bb0523a57c 100644 --- a/src/NHibernate/Linq/Functions/StringGenerator.cs +++ b/src/NHibernate/Linq/Functions/StringGenerator.cs @@ -59,6 +59,13 @@ public IHqlGeneratorForMethod GetMethodGenerator(MethodInfo method) } public bool AllowsNullableReturnType(MethodInfo method) => false; + + /// + public bool TryGetCollectionParameter(MethodCallExpression expression, out ConstantExpression collectionParameter) + { + collectionParameter = null; + return false; + } } public class LengthGenerator : BaseHqlGeneratorForProperty diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index ad39397fd71..671d949fc02 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -34,6 +34,8 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression protected virtual QueryMode QueryMode { get; } + internal IDictionary NamedParameters { get; } + private readonly Expression _expression; private readonly IDictionary _constantToParameterMap; @@ -56,12 +58,12 @@ internal NhLinqExpression(QueryMode queryMode, Expression expression, ISessionFa // referenced from the main query. LinqLogging.LogExpression("Expression (partially evaluated)", _expression); - _expression = ExpressionParameterVisitor.Visit(preTransformResult, out _constantToParameterMap); + _constantToParameterMap = ExpressionParameterVisitor.Visit(preTransformResult); ParameterValuesByName = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name, - p => System.Tuple.Create(p.Value, p.Type)); - - Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap); + p => System.Tuple.Create(p.Value, p.Type)); + NamedParameters = _constantToParameterMap.Values.Distinct().ToDictionary(p => p.Name); + Key = ExpressionKeyVisitor.Visit(_expression, _constantToParameterMap, sessionFactory); Type = _expression.Type; @@ -88,6 +90,7 @@ public IASTNode Translate(ISessionFactoryImplementor sessionFactory, bool filter var requiredHqlParameters = new List(); var queryModel = NhRelinqQueryParser.Parse(_expression); queryModel.TransformExpressions(TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers); + ParameterTypeLocator.SetParameterTypes(_constantToParameterMap, queryModel, TargetType, sessionFactory, true); var visitorParameters = new VisitorParameters(sessionFactory, _constantToParameterMap, requiredHqlParameters, new QuerySourceNamer(), TargetType, QueryMode); diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 99a8e009571..9826ffdbf06 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -29,8 +29,8 @@ private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel q { _sessionFactory = sessionFactory; var joiner = new Joiner(queryModel, AddJoin); - _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); - _whereJoinDetector = new WhereJoinDetector(this, joiner); + _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner, _sessionFactory); + _whereJoinDetector = new WhereJoinDetector(this, joiner, _sessionFactory); } public static void ReWrite(QueryModel queryModel, VisitorParameters parameters) diff --git a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs index ef4981d2aec..fdce29edd9a 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionKeyVisitor.cs @@ -7,7 +7,10 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Text; +using NHibernate.Engine; using NHibernate.Param; +using NHibernate.Type; +using NHibernate.Util; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -22,22 +25,46 @@ namespace NHibernate.Linq.Visitors public class ExpressionKeyVisitor : RelinqExpressionVisitor { private readonly IDictionary _constantToParameterMap; + private readonly ISessionFactoryImplementor _sessionFactory; readonly StringBuilder _string = new StringBuilder(); - private ExpressionKeyVisitor(IDictionary constantToParameterMap) + private ExpressionKeyVisitor( + IDictionary constantToParameterMap, + ISessionFactoryImplementor sessionFactory) { _constantToParameterMap = constantToParameterMap; + _sessionFactory = sessionFactory; } + // Since v5.3 + [Obsolete("Use the overload with ISessionFactoryImplementor parameter")] public static string Visit(Expression expression, IDictionary parameters) { - var visitor = new ExpressionKeyVisitor(parameters); + var visitor = new ExpressionKeyVisitor(parameters, null); visitor.Visit(expression); return visitor.ToString(); } + /// + /// Generates the key for the expression. + /// + /// The expression. + /// The session factory. + /// Parameters found in . + /// The key for the expression. + public static string Visit( + Expression rootExpression, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + var visitor = new ExpressionKeyVisitor(parameters, sessionFactory); + visitor.Visit(rootExpression); + + return visitor.ToString(); + } + public override string ToString() { return _string.ToString(); @@ -86,49 +113,70 @@ protected override Expression VisitConstant(ConstantExpression expression) throw new InvalidOperationException("Cannot visit a constant without a constant to parameter map."); if (_constantToParameterMap.TryGetValue(expression, out param)) { - // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. - if (param.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = param.Value as IEnumerable; - if (value != null && !(value is string) && !value.Cast().Any()) - { - _string.Append("EmptyList"); - } - else - { - _string.Append(param.Name); - } - } + VisitParameter(param); } else { - if (expression.Value == null) - { - _string.Append("NULL"); - } - else - { - var value = expression.Value as IEnumerable; - if (value != null && !(value is string) && !(value is IQueryable)) - { - _string.Append("{"); - _string.Append(String.Join(",", value.Cast())); - _string.Append("}"); - } - else - { - _string.Append(expression.Value); - } - } + VisitConstantValue(expression.Value); } return base.VisitConstant(expression); } + private void VisitConstantValue(object value) + { + if (value == null) + { + _string.Append("NULL"); + return; + } + + if (value is IEnumerable enumerable && !(value is IQueryable)) + { + _string.Append("{"); + _string.Append(string.Join(",", enumerable.Cast())); + _string.Append("}"); + return; + } + + // When MappedAs is used we have to put all sql types information in the key in order to + // distinct when different precisions/sizes are used. + if (_sessionFactory != null && value is IType type) + { + _string.Append(type.Name); + _string.Append('['); + _string.Append(string.Join(",", type.SqlTypes(_sessionFactory).Select(o => o.ToString()))); + _string.Append(']'); + return; + } + + _string.Append(value); + } + + private void VisitParameter(NamedParameter param) + { + // Nulls generate different query plans. X = variable generates a different query depending on if variable is null or not. + if (param.Value == null) + { + _string.Append("NULL"); + return; + } + + if (param.IsCollection && !((IEnumerable) param.Value).Cast().Any()) + { + _string.Append("EmptyList"); + } + else + { + _string.Append(param.Name); + } + + // Add the type in order to avoid invalid parameter conversions (string -> char) + _string.Append("<"); + _string.Append(param.Value.GetType()); + _string.Append(">"); + } + private T AppendCommas(T expression) where T : Expression { Visit(expression); @@ -159,6 +207,20 @@ protected override Expression VisitMember(MemberExpression expression) return expression; } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var memberBinder)) + { + Visit(expression.Arguments[1]); + FormatBinder(memberBinder); + return expression; + } + + return base.VisitInvocation(expression); + } +#endif + protected override Expression VisitMethodCall(MethodCallExpression expression) { Visit(expression.Object); @@ -218,8 +280,8 @@ protected override Expression VisitQuerySourceReference(Remotion.Linq.Clauses.Ex protected override Expression VisitDynamic(DynamicExpression expression) { - FormatBinder(expression.Binder); Visit(expression.Arguments, AppendCommas); + FormatBinder(expression.Binder); return expression; } diff --git a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs index 45134248a51..f6a9e5de43f 100644 --- a/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs +++ b/src/NHibernate/Linq/Visitors/ExpressionParameterVisitor.cs @@ -1,9 +1,11 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Engine; +using NHibernate.Linq.Functions; using NHibernate.Param; using NHibernate.Type; using NHibernate.Util; @@ -18,23 +20,18 @@ public class ExpressionParameterVisitor : RelinqExpressionVisitor { private readonly Dictionary _parameters = new Dictionary(); private readonly Dictionary _variableParameters = new Dictionary(); + private readonly HashSet _collectionParameters = new HashSet(); private readonly IDictionary _queryVariables; private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; - private static readonly MethodInfo QueryableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), 0); - private static readonly MethodInfo QueryableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), 0); - private static readonly MethodInfo EnumerableSkipDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), 0); - private static readonly MethodInfo EnumerableTakeDefinition = - ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), 0); - - private readonly ICollection _pagingMethods = new HashSet - { - QueryableSkipDefinition, QueryableTakeDefinition, - EnumerableSkipDefinition, EnumerableTakeDefinition - }; + private static readonly HashSet PagingMethods = new HashSet + { + ReflectionCache.EnumerableMethods.SkipDefinition, + ReflectionCache.EnumerableMethods.TakeDefinition, + ReflectionCache.QueryableMethods.SkipDefinition, + ReflectionCache.QueryableMethods.TakeDefinition + }; // Since v5.3 [Obsolete("Please use overload with preTransformationResult parameter instead.")] @@ -47,6 +44,7 @@ public ExpressionParameterVisitor(PreTransformationResult preTransformationResul { _sessionFactory = preTransformationResult.SessionFactory; _queryVariables = preTransformationResult.QueryVariables; + _functionRegistry = _sessionFactory.Settings.LinqToHqlGeneratorsRegistry; } // Since v5.3 @@ -59,22 +57,19 @@ public static IDictionary Visit(Expression e return visitor._parameters; } - public static Expression Visit( - PreTransformationResult preTransformationResult, - out IDictionary parameters) + public static IDictionary Visit(PreTransformationResult preTransformationResult) { var visitor = new ExpressionParameterVisitor(preTransformationResult); - var expression = visitor.Visit(preTransformationResult.Expression); - parameters = visitor._parameters; - - return expression; + visitor.Visit(preTransformationResult.Expression); + return visitor._parameters; } protected override Expression VisitMethodCall(MethodCallExpression expression) { - if (expression.Method.Name == nameof(LinqExtensionMethods.MappedAs) && expression.Method.DeclaringType == typeof(LinqExtensionMethods)) + if (VisitorUtil.IsMappedAs(expression.Method)) { var rawParameter = Visit(expression.Arguments[0]); + // TODO 6.0: Remove below code and return expression as this logic is now inside ConstantTypeLocator var parameter = rawParameter as ConstantExpression; var type = expression.Arguments[1] as ConstantExpression; if (parameter == null) @@ -95,10 +90,10 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) ? expression.Method.GetGenericMethodDefinition() : expression.Method; - if (_pagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) + if (PagingMethods.Contains(method) && !_sessionFactory.Dialect.SupportsVariableLimit) { - //TODO: find a way to make this code cleaner var query = Visit(expression.Arguments[0]); + //TODO 6.0: Remove the below code and return expression var arg = expression.Arguments[1]; if (query == expression.Arguments[0]) @@ -107,6 +102,13 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return Expression.Call(null, expression.Method, query, arg); } + if (_functionRegistry != null && + _functionRegistry.TryGetGenerator(method, out var generator) && + generator.TryGetCollectionParameters(expression, out var collectionParameter)) + { + _collectionParameters.Add(collectionParameter); + } + if (VisitorUtil.IsDynamicComponentDictionaryGetter(expression, _sessionFactory)) { return expression; @@ -115,6 +117,20 @@ protected override Expression VisitMethodCall(MethodCallExpression expression) return base.VisitMethodCall(expression); } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression expression) + { + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out _)) + { + // Avoid adding System.Runtime.CompilerServices.CallSite instance as a parameter + base.Visit(expression.Arguments[1]); + return expression; + } + + return base.VisitInvocation(expression); + } +#endif + protected override Expression VisitConstant(ConstantExpression expression) { if (!_parameters.ContainsKey(expression) && !typeof(IQueryable).IsAssignableFrom(expression.Type) && !IsNullObject(expression)) @@ -125,11 +141,14 @@ protected override Expression VisitConstant(ConstantExpression expression) // We have a bit more information about the null parameter value. // Figure out a type so that HQL doesn't break on the null. (Related to NH-2430) + // In v5.3 types are calculated by ConstantTypeLocator, this logic is only for back compatibility. + // TODO 6.0: Remove if (expression.Value == null) type = NHibernateUtil.GuessType(expression.Type); // Constant characters should be sent as strings - if (expression.Type == typeof(char)) + // TODO 6.0: Remove + if (_queryVariables == null && expression.Type == typeof(char)) { value = value.ToString(); } @@ -144,13 +163,13 @@ protected override Expression VisitConstant(ConstantExpression expression) _queryVariables.TryGetValue(expression, out var variable) && !_variableParameters.TryGetValue(variable, out parameter)) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); _variableParameters.Add(variable, parameter); } if (parameter == null) { - parameter = new NamedParameter("p" + (_parameters.Count + 1), value, type); + parameter = CreateParameter(expression, value, type); } _parameters.Add(expression, parameter); @@ -161,6 +180,15 @@ protected override Expression VisitConstant(ConstantExpression expression) return base.VisitConstant(expression); } + private NamedParameter CreateParameter(ConstantExpression expression, object value, IType type) + { + return new NamedParameter( + "p" + (_parameters.Count + 1), + value, + type, + _collectionParameters.Contains(expression)); + } + private static bool IsNullObject(ConstantExpression expression) { return expression.Type == typeof(Object) && expression.Value == null; diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index cd9cd49eadb..a16968cccb2 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -226,18 +226,14 @@ private HqlTreeNode VisitNhNominated(NhNominatedExpression nhNominatedExpression private HqlTreeNode VisitInvocationExpression(InvocationExpression expression) { - //This is an ugly workaround for dynamic expressions. - //Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression - if (expression.Arguments.Count == 2 && - expression.Arguments[0] is ConstantExpression constant && - constant.Value is CallSite site && - site.Binder is GetMemberBinder binder) +#if NETCOREAPP2_0 + if (ExpressionsHelper.TryGetDynamicMemberBinder(expression, out var binder)) { return _hqlTreeBuilder.Dot( VisitExpression(expression.Arguments[1]).AsExpression(), _hqlTreeBuilder.Ident(binder.Name)); } - +#endif return VisitExpression(expression.Expression); } diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 580ba3cf00c..019769fccb1 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -19,16 +19,18 @@ internal class MemberExpressionJoinDetector : RelinqExpressionVisitor { private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private bool _requiresJoinForNonIdentifier; private bool _preventJoinsInConditionalTest; private bool _hasIdentifier; private int _memberExpressionDepth; - public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } protected override Expression VisitMember(MemberExpression expression) @@ -55,7 +57,7 @@ protected override Expression VisitMember(MemberExpression expression) ((_requiresJoinForNonIdentifier && !_hasIdentifier) || _memberExpressionDepth > 0) && _joiner.CanAddJoin(expression)) { - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); return _joiner.AddJoin(result, key); } diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs new file mode 100644 index 00000000000..34326640169 --- /dev/null +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -0,0 +1,321 @@ +using System.Collections.Generic; +using System.Dynamic; +using System.Linq.Expressions; +using NHibernate.Engine; +using NHibernate.Param; +using NHibernate.Type; +using NHibernate.Util; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; + +namespace NHibernate.Linq.Visitors +{ + /// + /// Locates parameter actual type based on its usage. + /// + public static class ParameterTypeLocator + { + /// + /// List of for which the should be related to the other side + /// of a (e.g. o.MyEnum == MyEnum.Option -> MyEnum.Option should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet ValidBinaryExpressionTypes = new HashSet + { + ExpressionType.Equal, + ExpressionType.NotEqual, + ExpressionType.GreaterThanOrEqual, + ExpressionType.GreaterThan, + ExpressionType.LessThan, + ExpressionType.LessThanOrEqual, + ExpressionType.Coalesce, + ExpressionType.Assign + }; + + /// + /// List of for which the should be copied across + /// as related (e.g. (o.MyEnum ?? MyEnum.Option) == MyEnum.Option2 -> MyEnum.Option2 should have o.MyEnum as a related + /// ). + /// + private static readonly HashSet NonVoidOperators = new HashSet + { + ExpressionType.Coalesce, + ExpressionType.Conditional + }; + + /// + /// Set query parameter types based on the given query model. + /// + /// The query parameters. + /// The query model. + /// The target entity type. + /// The session factory. + public static void SetParameterTypes( + IDictionary parameters, + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory) + { + SetParameterTypes(parameters, queryModel, targetType, sessionFactory, false); + } + + internal static void SetParameterTypes( + IDictionary parameters, + QueryModel queryModel, + System.Type targetType, + ISessionFactoryImplementor sessionFactory, + bool removeMappedAsCalls) + { + if (parameters.Count == 0) + { + return; + } + + var visitor = new ConstantTypeLocatorVisitor(removeMappedAsCalls, targetType, parameters, sessionFactory); + queryModel.TransformExpressions(visitor.Visit); + + foreach (var pair in visitor.ConstantExpressions) + { + var type = pair.Value; + var constantExpression = pair.Key; + if (!parameters.TryGetValue(constantExpression, out var namedParameter)) + { + continue; + } + + if (type != null) + { + // MappedAs was used + namedParameter.Type = type; + continue; + } + + // In order to get the actual type we have to check first the related member expressions, as + // an enum is translated in a numeric type when used in a BinaryExpression and also it can be mapped as string. + // By getting the type from a related member expression we also get the correct length in case of StringType + // or precision when having a DecimalType. + if (visitor.RelatedExpressions.TryGetValue(constantExpression, out var memberExpressions)) + { + foreach (var memberExpression in memberExpressions) + { + if (ExpressionsHelper.TryGetMappedType( + sessionFactory, + memberExpression, + out type, + out _, + out _, + out _)) + { + break; + } + } + } + + // No related MemberExpressions was found, guess the type by value or its type when null. + if (type == null) + { + type = constantExpression.Value != null + ? ParameterHelper.TryGuessType(constantExpression.Value, sessionFactory, namedParameter.IsCollection) + : ParameterHelper.TryGuessType(constantExpression.Type, sessionFactory, namedParameter.IsCollection); + } + + namedParameter.Type = type; + } + } + + private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor + { + private readonly bool _removeMappedAsCalls; + private readonly System.Type _targetType; + private readonly IDictionary _parameters; + private readonly ISessionFactoryImplementor _sessionFactory; + public readonly Dictionary ConstantExpressions = + new Dictionary(); + public readonly Dictionary> RelatedExpressions = + new Dictionary>(); + + public ConstantTypeLocatorVisitor( + bool removeMappedAsCalls, + System.Type targetType, + IDictionary parameters, + ISessionFactoryImplementor sessionFactory) + { + _removeMappedAsCalls = removeMappedAsCalls; + _targetType = targetType; + _sessionFactory = sessionFactory; + _parameters = parameters; + } + + protected override Expression VisitBinary(BinaryExpression node) + { + node = (BinaryExpression) base.VisitBinary(node); + if (!ValidBinaryExpressionTypes.Contains(node.NodeType)) + { + return node; + } + + var left = Unwrap(node.Left); + var right = Unwrap(node.Right); + if (node.NodeType == ExpressionType.Assign) + { + VisitAssign(left, right); + } + else + { + AddRelatedExpression(node, left, right); + AddRelatedExpression(node, right, left); + } + + return node; + } + + protected override Expression VisitConditional(ConditionalExpression node) + { + node = (ConditionalExpression) base.VisitConditional(node); + var ifTrue = Unwrap(node.IfTrue); + var ifFalse = Unwrap(node.IfFalse); + AddRelatedExpression(node, ifTrue, ifFalse); + AddRelatedExpression(node, ifFalse, ifTrue); + + return node; + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (VisitorUtil.IsMappedAs(node.Method)) + { + var rawParameter = Visit(node.Arguments[0]); + var parameter = rawParameter as ConstantExpression; + var type = node.Arguments[1] as ConstantExpression; + if (parameter == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} must be called on an expression which can be evaluated as " + + $"{nameof(ConstantExpression)}. It was call on {rawParameter?.GetType().Name ?? "null"} instead."); + if (type == null) + throw new HibernateException( + $"{nameof(LinqExtensionMethods.MappedAs)} type must be supplied as {nameof(ConstantExpression)}. " + + $"It was {node.Arguments[1]?.GetType().Name ?? "null"} instead."); + + ConstantExpressions[parameter] = (IType) type.Value; + + return _removeMappedAsCalls + ? rawParameter + : node; + } + + return base.VisitMethodCall(node); + } + + protected override Expression VisitConstant(ConstantExpression node) + { + if (node.Value is IEntityNameProvider || RelatedExpressions.ContainsKey(node) || !_parameters.ContainsKey(node)) + { + return node; + } + + RelatedExpressions.Add(node, new HashSet()); + ConstantExpressions.Add(node, null); + return node; + } + + public override Expression Visit(Expression node) + { + if (node is SubQueryExpression subQueryExpression) + { + subQueryExpression.QueryModel.TransformExpressions(Visit); + } + + return base.Visit(node); + } + + private void VisitAssign(Expression leftNode, Expression rightNode) + { + // Insert and Update statements have assign expressions, where the left side is a parameter and its name + // represents the property path to be assigned + if (!(leftNode is ParameterExpression parameterExpression) || + !(rightNode is ConstantExpression constantExpression)) + { + return; + } + + var entityName = _sessionFactory.TryGetGuessEntityName(_targetType); + if (entityName == null) + { + return; + } + + var persister = _sessionFactory.GetEntityPersister(entityName); + ConstantExpressions[constantExpression] = persister.EntityMetamodel.GetPropertyType(parameterExpression.Name); + } + + private void AddRelatedExpression(Expression node, Expression left, Expression right) + { + if (left.NodeType == ExpressionType.MemberAccess || + IsDynamicMember(left) || + left is QuerySourceReferenceExpression) + { + AddRelatedExpression(right, left); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddRelatedExpression(node, left); + } + } + + // Copy all found MemberExpressions to the other side + // (e.g. (o.Prop ?? constant1) == constant2 -> copy o.Prop to constant2) + if (RelatedExpressions.TryGetValue(left, out var set)) + { + foreach (var nestedMemberExpression in set) + { + AddRelatedExpression(right, nestedMemberExpression); + if (NonVoidOperators.Contains(node.NodeType)) + { + AddRelatedExpression(node, nestedMemberExpression); + } + } + } + } + + private void AddRelatedExpression(Expression expression, Expression relatedExpression) + { + if (!RelatedExpressions.TryGetValue(expression, out var set)) + { + set = new HashSet(); + RelatedExpressions.Add(expression, set); + } + + set.Add(relatedExpression); + } + + private bool IsDynamicMember(Expression expression) + { + switch (expression) + { +#if NETCOREAPP2_0 + case InvocationExpression invocationExpression: + // session.Query().Where("Properties.Name == @0", "First Product") + return ExpressionsHelper.TryGetDynamicMemberBinder(invocationExpression, out _); +#endif + case DynamicExpression dynamicExpression: + return dynamicExpression.Binder is GetMemberBinder; + case MethodCallExpression methodCallExpression: + // session.Query() where p.Properties["Name"] == "First Product" select p + return VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(methodCallExpression, out _); + default: + return false; + } + } + + private static Expression Unwrap(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + return unaryExpression.Operand; + } + + return expression; + } + } + } +} diff --git a/src/NHibernate/Linq/Visitors/VisitorUtil.cs b/src/NHibernate/Linq/Visitors/VisitorUtil.cs index 22ac89dd0aa..40dbaeb4fb7 100644 --- a/src/NHibernate/Linq/Visitors/VisitorUtil.cs +++ b/src/NHibernate/Linq/Visitors/VisitorUtil.cs @@ -13,25 +13,12 @@ public static class VisitorUtil { public static bool IsDynamicComponentDictionaryGetter(MethodInfo method, Expression targetObject, IEnumerable arguments, ISessionFactory sessionFactory, out string memberName) { - memberName = null; - - // A dynamic component must be an IDictionary with a string key. - - if (method.Name != "get_Item" || !typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type)) - return false; - - var key = arguments.First() as ConstantExpression; - if (key == null || key.Type != typeof(string)) - return false; - - // The potential member name - memberName = (string)key.Value; - - // Need the owning member (the dictionary). - var member = targetObject as MemberExpression; - if (member == null) + if (!TryGetPotentialDynamicComponentDictionaryMember(method, targetObject, arguments, out memberName)) + { return false; + } + var member = (MemberExpression) targetObject; var memberPath = member.Member.Name; var metaData = sessionFactory.GetClassMetadata(member.Expression.Type); @@ -131,5 +118,42 @@ public static string GetMemberPath(this MemberExpression memberExpression) } return path; } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember(MethodCallExpression expression, out string memberName) + { + return TryGetPotentialDynamicComponentDictionaryMember( + expression.Method, + expression.Object, + expression.Arguments, + out memberName); + } + + internal static bool TryGetPotentialDynamicComponentDictionaryMember( + MethodInfo method, + Expression targetObject, + IEnumerable arguments, + out string memberName) + { + memberName = null; + // A dynamic component must be an IDictionary with a string key. + if (method.Name != "get_Item" || + targetObject.NodeType != ExpressionType.MemberAccess || // Need the owning member (the dictionary). + !(arguments.First() is ConstantExpression key) || + key.Type != typeof(string) || + (!typeof(IDictionary).IsAssignableFrom(targetObject.Type) && !typeof(IDictionary).IsAssignableFrom(targetObject.Type))) + { + return false; + } + + // The potential member name + memberName = (string) key.Value; + return true; + } + + internal static bool IsMappedAs(MethodInfo methodInfo) + { + return methodInfo.Name == nameof(LinqExtensionMethods.MappedAs) && + methodInfo.DeclaringType == typeof(LinqExtensionMethods); + } } } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 886d4e0e2b1..689457a7403 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -62,6 +62,7 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // TODO: There are a number of types of expressions that we didn't handle here due to time constraints. For example, the ?: operator could be checked easily. private readonly IIsEntityDecider _isEntityDecider; private readonly IJoiner _joiner; + private readonly ISessionFactoryImplementor _sessionFactory; private readonly Stack _handled = new Stack(); @@ -71,10 +72,11 @@ internal class WhereJoinDetector : RelinqExpressionVisitor // The following is used for member expressions traversal. private int _memberExpressionDepth; - internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) + internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner, ISessionFactoryImplementor sessionFactory) { _isEntityDecider = isEntityDecider; _joiner = joiner; + _sessionFactory = sessionFactory; } public Expression Transform(Expression expression) @@ -329,7 +331,7 @@ protected override Expression VisitMember(MemberExpression expression) { // Don't add joins for things like a.B == a.C where B and C are entities. // We only need to join B when there's something like a.B.D. - var key = ExpressionKeyVisitor.Visit(expression, null); + var key = ExpressionKeyVisitor.Visit(expression, null, _sessionFactory); if (_memberExpressionDepth > 0 && _joiner.CanAddJoin(expression)) { diff --git a/src/NHibernate/Param/NamedParameter.cs b/src/NHibernate/Param/NamedParameter.cs index b42f69925f0..a9a9b67de2b 100644 --- a/src/NHibernate/Param/NamedParameter.cs +++ b/src/NHibernate/Param/NamedParameter.cs @@ -5,16 +5,24 @@ namespace NHibernate.Param public class NamedParameter { public NamedParameter(string name, object value, IType type) + : this(name, value, type, false) + { + } + + internal NamedParameter(string name, object value, IType type, bool isCollection) { Name = name; Value = value; Type = type; + IsCollection = isCollection; } public string Name { get; private set; } public object Value { get; internal set; } public IType Type { get; internal set; } + public virtual bool IsCollection { get; } + public bool Equals(NamedParameter other) { if (ReferenceEquals(null, other)) @@ -38,4 +46,4 @@ public override int GetHashCode() return (Name != null ? Name.GetHashCode() : 0); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 08a60aeeb66..eebee36c8dd 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -15,6 +16,7 @@ using NHibernate.Type; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; namespace NHibernate.Util { @@ -30,6 +32,33 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } +#if NETCOREAPP2_0 + /// + /// Try to retrieve from a reduced expression. + /// + /// The reduced dynamic expression. + /// The out binder parameter. + /// Whether the binder was found. + internal static bool TryGetDynamicMemberBinder(InvocationExpression expression, out GetMemberBinder memberBinder) + { + // This is an ugly workaround for dynamic expressions in .NET Core. In .NET Core a dynamic expression is reduced + // when first visited by a expression visitor that is not a DynamicExpressionVisitor, where in .NET Framework it is never reduced. + // As RelinqExpressionVisitor does not extend DynamicExpressionVisitor, we will always have a reduced dynamic expression in .NET Core. + // Unfortunately we can not tap into the expression tree earlier to intercept the dynamic expression + if (expression.Arguments.Count == 2 && + expression.Arguments[0] is ConstantExpression constant && + constant.Value is CallSite site && + site.Binder is GetMemberBinder binder) + { + memberBinder = binder; + return true; + } + + memberBinder = null; + return false; + } +#endif + /// /// Check whether the given expression represent a variable. /// @@ -635,6 +664,34 @@ protected override Expression VisitMember(MemberExpression node) return base.Visit(node.Expression); } +#if NETCOREAPP2_0 + protected override Expression VisitInvocation(InvocationExpression node) + { + if (TryGetDynamicMemberBinder(node, out var binder)) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[1]); + } + + return base.VisitInvocation(node); + } +#endif + + protected override Expression VisitDynamic(DynamicExpression node) + { + if (node.Binder is GetMemberBinder binder) + { + _memberPaths.Push(new MemberMetadata(binder.Name, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Arguments[0]); + } + + return Visit(node); + } + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression node) { if (node.ReferencedQuerySource is IFromClause fromClause) @@ -721,6 +778,14 @@ protected override Expression VisitMethodCall(MethodCallExpression node) ); } + if (VisitorUtil.TryGetPotentialDynamicComponentDictionaryMember(node, out var memberName)) + { + _memberPaths.Push(new MemberMetadata(memberName, _convertType, _hasIndexer)); + _convertType = null; + _hasIndexer = false; + return base.Visit(node.Object); + } + return Visit(node); } diff --git a/src/NHibernate/Util/ParameterHelper.cs b/src/NHibernate/Util/ParameterHelper.cs new file mode 100644 index 00000000000..d0b6bd14625 --- /dev/null +++ b/src/NHibernate/Util/ParameterHelper.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Proxy; +using NHibernate.Type; + +namespace NHibernate.Util +{ + internal static class ParameterHelper + { + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// Whether is a collection. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType TryGuessType(object param, ISessionFactoryImplementor sessionFactory, bool isCollection) + { + if (param == null) + { + return null; + } + + if (param is IEnumerable enumerable && isCollection) + { + var firstValue = enumerable.Cast().FirstOrDefault(); + return firstValue == null + ? TryGuessType(enumerable.GetCollectionElementType(), sessionFactory) + : TryGuessType(firstValue, sessionFactory, false); + } + + var clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the param's value. + /// + /// The object to guess the of. + /// The session factory to search for entity persister. + /// An for the object. + /// + /// Thrown when the param is null because the + /// can't be guess from a null value. + /// + public static IType GuessType(object param, ISessionFactoryImplementor sessionFactory) + { + if (param == null) + { + throw new ArgumentNullException(nameof(param), "The IType can not be guessed for a null value."); + } + + System.Type clazz = NHibernateProxyHelper.GetClassWithoutInitializingProxy(param); + return GuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// Whether is a collection. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory, bool isCollection) + { + if (clazz == null) + { + return null; + } + + if (isCollection) + { + return TryGuessType(ReflectHelper.GetCollectionElementType(clazz), sessionFactory, false); + } + + return TryGuessType(clazz, sessionFactory); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType GuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + throw new ArgumentNullException(nameof(clazz), "The IType can not be guessed for a null value."); + } + + return TryGuessType(clazz, sessionFactory) ?? + throw new HibernateException("Could not determine a type for class: " + clazz.AssemblyQualifiedName); + } + + /// + /// Guesses the from the . + /// + /// The to guess the of. + /// The session factory to search for entity persister. + /// An for the . + /// + /// Thrown when the clazz is null because the + /// can't be guess from a null type. + /// + public static IType TryGuessType(System.Type clazz, ISessionFactoryImplementor sessionFactory) + { + if (clazz == null) + { + return null; + } + + var type = TypeFactory.HeuristicType(clazz); + if (type == null || type is SerializableType) + { + if (sessionFactory.TryGetEntityPersister(clazz.FullName) != null) + { + return NHibernateUtil.Entity(clazz); + } + } + + return type; + } + } +} diff --git a/src/NHibernate/Util/ReflectionCache.cs b/src/NHibernate/Util/ReflectionCache.cs index c40a395f98d..47fde15950d 100644 --- a/src/NHibernate/Util/ReflectionCache.cs +++ b/src/NHibernate/Util/ReflectionCache.cs @@ -54,6 +54,11 @@ internal static class EnumerableMethods internal static readonly MethodInfo ToListDefinition = ReflectHelper.FastGetMethodDefinition(Enumerable.ToList, default(IEnumerable)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Skip, default(IEnumerable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Enumerable.Take, default(IEnumerable), default(int)); } internal static class MethodBaseMethods @@ -215,6 +220,11 @@ internal static class QueryableMethods ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); internal static readonly MethodInfo AverageWithSelectorOfNullableDecimalDefinition = ReflectHelper.FastGetMethodDefinition(Queryable.Average, default(IQueryable), default(Expression>)); + + internal static readonly MethodInfo SkipDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Skip, default(IQueryable), default(int)); + internal static readonly MethodInfo TakeDefinition = + ReflectHelper.FastGetMethodDefinition(Queryable.Take, default(IQueryable), default(int)); } internal static class TypeMethods