diff --git a/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs b/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs index 8f53b1f0cf6..cc04d951659 100644 --- a/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs +++ b/src/NHibernate.Test/Async/Linq/CustomExtensionsExample.cs @@ -14,6 +14,7 @@ using System.Reflection; using System.Text.RegularExpressions; using NHibernate.Cfg; +using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Hql.Ast; using NHibernate.Linq.Functions; using NHibernate.Linq.Visitors; @@ -33,6 +34,14 @@ protected override void Configure(NHibernate.Cfg.Configuration configuration) configuration.LinqToHqlGeneratorsRegistry(); } + [Test] + public async Task CanUseObjectEqualsAsync() + { + var users = await (db.Users.Where(o => ((object) EnumStoredAsString.Medium).Equals(o.NullableEnum1)).ToListAsync()); + Assert.That(users.Count, Is.EqualTo(2)); + Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True); + } + [Test] public async Task CanUseMyCustomExtensionAsync() { diff --git a/src/NHibernate.Test/Linq/CustomExtensionsExample.cs b/src/NHibernate.Test/Linq/CustomExtensionsExample.cs index 89300508a4d..c9b76f92cec 100644 --- a/src/NHibernate.Test/Linq/CustomExtensionsExample.cs +++ b/src/NHibernate.Test/Linq/CustomExtensionsExample.cs @@ -4,6 +4,7 @@ using System.Reflection; using System.Text.RegularExpressions; using NHibernate.Cfg; +using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Hql.Ast; using NHibernate.Linq.Functions; using NHibernate.Linq.Visitors; @@ -30,6 +31,7 @@ public MyLinqToHqlGeneratorsRegistry():base() { RegisterGenerator(ReflectHelper.GetMethodDefinition(() => MyLinqExtensions.IsLike(null, null)), new IsLikeGenerator()); + RegisterGenerator(ReflectHelper.GetMethodDefinition(() => new object().Equals(null)), new ObjectEqualsGenerator()); } } @@ -48,6 +50,21 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, } } + public class ObjectEqualsGenerator : BaseHqlGeneratorForMethod + { + public ObjectEqualsGenerator() + { + SupportedMethods = new[] { ReflectHelper.GetMethodDefinition(() => new object().Equals(null)) }; + } + + public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, + ReadOnlyCollection arguments, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor) + { + return treeBuilder.Equality(visitor.Visit(targetObject).AsExpression(), + visitor.Visit(arguments[0]).AsExpression()); + } + } + [TestFixture] public class CustomExtensionsExample : LinqTestCase { @@ -56,6 +73,14 @@ protected override void Configure(NHibernate.Cfg.Configuration configuration) configuration.LinqToHqlGeneratorsRegistry(); } + [Test] + public void CanUseObjectEquals() + { + var users = db.Users.Where(o => ((object) EnumStoredAsString.Medium).Equals(o.NullableEnum1)).ToList(); + Assert.That(users.Count, Is.EqualTo(2)); + Assert.That(users.All(c => c.NullableEnum1 == EnumStoredAsString.Medium), Is.True); + } + [Test] public void CanUseMyCustomExtension() { diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index ccc950d57f2..977e6d9c015 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -156,6 +156,11 @@ private static IType GetParameterType( return candidateType; } + if (visitor.NotGuessableConstants.Contains(constantExpression)) + { + return null; + } + // No related MemberExpressions was found, guess the type by value or its type when null. // When a numeric parameter is compared to different columns with different types (e.g. Where(o => o.Single >= singleParam || o.Double <= singleParam)) // do not change the parameter type, but instead cast the parameter when comparing with different column types. @@ -166,10 +171,13 @@ private static IType GetParameterType( private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor { + private bool _hqlGenerator; private readonly bool _removeMappedAsCalls; private readonly System.Type _targetType; private readonly IDictionary _parameters; private readonly ISessionFactoryImplementor _sessionFactory; + private readonly ILinqToHqlGeneratorsRegistry _functionRegistry; + public readonly HashSet NotGuessableConstants = new HashSet(); public readonly Dictionary ConstantExpressions = new Dictionary(); public readonly Dictionary> ParameterConstants = @@ -187,6 +195,7 @@ public ConstantTypeLocatorVisitor( _targetType = targetType; _sessionFactory = sessionFactory; _parameters = parameters; + _functionRegistry = sessionFactory.Settings.LinqToHqlGeneratorsRegistry; } protected override Expression VisitBinary(BinaryExpression node) @@ -257,6 +266,16 @@ protected override Expression VisitMethodCall(MethodCallExpression node) return node; } + // For hql method generators we do not want to guess the parameter type here, let hql logic figure it out. + if (_functionRegistry.TryGetGenerator(node.Method, out _)) + { + var origHqlGenerator = _hqlGenerator; + _hqlGenerator = true; + var expression = base.VisitMethodCall(node); + _hqlGenerator = origHqlGenerator; + return expression; + } + return base.VisitMethodCall(node); } @@ -267,6 +286,11 @@ protected override Expression VisitConstant(ConstantExpression node) return node; } + if (_hqlGenerator) + { + NotGuessableConstants.Add(node); + } + RelatedExpressions.Add(node, new HashSet()); ConstantExpressions.Add(node, null); if (!ParameterConstants.TryGetValue(param, out var set))