diff --git a/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs b/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs index 017634dc4f5..abd2733e24c 100644 --- a/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs +++ b/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs @@ -510,5 +510,40 @@ public async Task StatelessSessionAsync() Assert.That(statelessPerson2.Id, Is.EqualTo(personId)); } } + + [Test] + public async Task QueryOverArithmeticAsync() + { + using (ISession s = OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + await (s.SaveAsync(new Person() {Name = "test person 1", Age = 20})); + await (s.SaveAsync(new Person() {Name = "test person 2", Age = 50})); + await (t.CommitAsync()); + } + + using (var s = OpenSession()) + { + var persons1 = await (s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == 20).ListAsync()); + var persons2 = await (s.QueryOver().Where(p => (-(-p.Age)) > 20).ListAsync()); + var persons3 = await (s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(21).ListAsync()); + var persons4 = await (s.QueryOver().WhereRestrictionOn(p => -(-p.Age)).IsBetween(19).And(21).ListAsync()); + var persons5 = await (s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(51).ListAsync()); + var persons6 = await (s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == p.Age - p.Age + 20).ListAsync()); +#pragma warning disable CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var persons7 = await (s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == null || p.Age * 2 == 20 * 1).ListAsync()); +#pragma warning restore CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var val1 = await (s.QueryOver().Select(p => p.Age * 2).Where(p => p.Age == 20).SingleOrDefaultAsync()); + + Assert.That(persons1.Count, Is.EqualTo(1)); + Assert.That(persons2.Count, Is.EqualTo(1)); + Assert.That(persons3.Count, Is.EqualTo(1)); + Assert.That(persons4.Count, Is.EqualTo(1)); + Assert.That(persons5.Count, Is.EqualTo(2)); + Assert.That(persons6.Count, Is.EqualTo(1)); + Assert.That(persons7.Count, Is.EqualTo(0)); + Assert.That(val1, Is.EqualTo(40)); + } + } } } diff --git a/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs b/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs index 0b3f46ba784..dd23c4b91b4 100644 --- a/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs +++ b/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs @@ -498,5 +498,40 @@ public void StatelessSession() Assert.That(statelessPerson2.Id, Is.EqualTo(personId)); } } + + [Test] + public void QueryOverArithmetic() + { + using (ISession s = OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + s.Save(new Person() {Name = "test person 1", Age = 20}); + s.Save(new Person() {Name = "test person 2", Age = 50}); + t.Commit(); + } + + using (var s = OpenSession()) + { + var persons1 = s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == 20).List(); + var persons2 = s.QueryOver().Where(p => (-(-p.Age)) > 20).List(); + var persons3 = s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(21).List(); + var persons4 = s.QueryOver().WhereRestrictionOn(p => -(-p.Age)).IsBetween(19).And(21).List(); + var persons5 = s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(51).List(); + var persons6 = s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == p.Age - p.Age + 20).List(); +#pragma warning disable CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var persons7 = s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == null || p.Age * 2 == 20 * 1).List(); +#pragma warning restore CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var val1 = s.QueryOver().Select(p => p.Age * 2).Where(p => p.Age == 20).SingleOrDefault(); + + Assert.That(persons1.Count, Is.EqualTo(1)); + Assert.That(persons2.Count, Is.EqualTo(1)); + Assert.That(persons3.Count, Is.EqualTo(1)); + Assert.That(persons4.Count, Is.EqualTo(1)); + Assert.That(persons5.Count, Is.EqualTo(2)); + Assert.That(persons6.Count, Is.EqualTo(1)); + Assert.That(persons7.Count, Is.EqualTo(0)); + Assert.That(val1, Is.EqualTo(40)); + } + } } } diff --git a/src/NHibernate/Criterion/ConstantProjection.cs b/src/NHibernate/Criterion/ConstantProjection.cs index 3504d633ab6..2d901a51b76 100644 --- a/src/NHibernate/Criterion/ConstantProjection.cs +++ b/src/NHibernate/Criterion/ConstantProjection.cs @@ -13,7 +13,7 @@ namespace NHibernate.Criterion public class ConstantProjection : SimpleProjection { private readonly object value; - private readonly TypedValue typedValue; + public TypedValue TypedValue { get; } public ConstantProjection(object value) : this(value, NHibernateUtil.GuessType(value.GetType())) { @@ -22,7 +22,7 @@ public ConstantProjection(object value) : this(value, NHibernateUtil.GuessType(v public ConstantProjection(object value, IType type) { this.value = value; - typedValue = new TypedValue(type, this.value); + TypedValue = new TypedValue(type, this.value); } public override bool IsAggregate @@ -43,19 +43,19 @@ public override bool IsGrouped public override SqlString ToSqlString(ICriteria criteria, int position, ICriteriaQuery criteriaQuery) { return new SqlString( - criteriaQuery.NewQueryParameter(typedValue).Single(), + criteriaQuery.NewQueryParameter(TypedValue).Single(), " as ", GetColumnAliases(position, criteria, criteriaQuery)[0]); } public override IType[] GetTypes(ICriteria criteria, ICriteriaQuery criteriaQuery) { - return new IType[] { typedValue.Type }; + return new IType[] { TypedValue.Type }; } public override TypedValue[] GetTypedValues(ICriteria criteria, ICriteriaQuery criteriaQuery) { - return new TypedValue[] { typedValue }; + return new TypedValue[] { TypedValue }; } } } diff --git a/src/NHibernate/Impl/ExpressionProcessor.cs b/src/NHibernate/Impl/ExpressionProcessor.cs index 874d1c6e063..945a4402494 100644 --- a/src/NHibernate/Impl/ExpressionProcessor.cs +++ b/src/NHibernate/Impl/ExpressionProcessor.cs @@ -5,6 +5,9 @@ using System.Runtime.CompilerServices; using System.Text.RegularExpressions; using NHibernate.Criterion; +using NHibernate.Dialect.Function; +using NHibernate.Engine; +using NHibernate.Type; using NHibernate.Util; using Expression = System.Linq.Expressions.Expression; @@ -84,16 +87,18 @@ public Order CreateOrder(Func orderStringDelegate, Func /// Retrieve the property name from a supplied PropertyProjection - /// Note: throws if the supplied IProjection is not a PropertyProjection + /// Note: throws if the supplied IProjection is not a IPropertyProjection /// public string AsProperty() { if (_property != null) return _property; - var propertyProjection = _projection as PropertyProjection; + var propertyProjection = _projection as IPropertyProjection; if (propertyProjection == null) throw new InvalidOperationException("Cannot determine property for " + _projection); return propertyProjection.PropertyName; } + + internal bool IsConstant(out ConstantProjection value) => (value = _projection as ConstantProjection) != null; } private static readonly Dictionary> _simpleExpressionCreators; @@ -101,6 +106,8 @@ public string AsProperty() private static readonly Dictionary>> _subqueryExpressionCreatorTypes; private static readonly Dictionary> _customMethodCallProcessors; private static readonly Dictionary> _customProjectionProcessors; + private static readonly Dictionary _binaryArithmethicTemplates = new Dictionary(); + private static readonly ISQLFunction _unaryNegateTemplate; static ExpressionProcessor() { @@ -195,6 +202,17 @@ static ExpressionProcessor() RegisterCustomProjection(() => Math.Round(default(double), default(int)), ProjectionsExtensions.ProcessRound); RegisterCustomProjection(() => Math.Round(default(decimal), default(int)), ProjectionsExtensions.ProcessRound); RegisterCustomProjection(() => ProjectionsExtensions.AsEntity(default(object)), ProjectionsExtensions.ProcessAsEntity); + + RegisterBinaryArithmeticExpression(ExpressionType.Add, "+"); + RegisterBinaryArithmeticExpression(ExpressionType.Subtract, "-"); + RegisterBinaryArithmeticExpression(ExpressionType.Multiply, "*"); + RegisterBinaryArithmeticExpression(ExpressionType.Divide, "/"); + _unaryNegateTemplate = new VarArgsSQLFunction("(-", string.Empty, ")"); + } + + private static void RegisterBinaryArithmeticExpression(ExpressionType type, string sqlOperand) + { + _binaryArithmethicTemplates[type] = new VarArgsSQLFunction("(", sqlOperand, ")"); } private static ICriterion Eq(ProjectionInfo property, object value) @@ -245,15 +263,13 @@ public static object FindValue(Expression expression) public static ProjectionInfo FindMemberProjection(Expression expression) { if (!IsMemberExpression(expression)) - return ProjectionInfo.ForProjection(Projections.Constant(FindValue(expression))); + return AsArithmeticProjection(expression) + ?? ProjectionInfo.ForProjection(Projections.Constant(FindValue(expression), NHibernateUtil.GuessType(expression.Type))); - var unaryExpression = expression as UnaryExpression; - if (unaryExpression != null) + var unwrapExpression = UnwrapConvertExpression(expression); + if (unwrapExpression != null) { - if (!IsConversion(unaryExpression.NodeType)) - throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); - - return FindMemberProjection(unaryExpression.Operand); + return FindMemberProjection(unwrapExpression); } var methodCallExpression = expression as MethodCallExpression; @@ -266,20 +282,69 @@ public static ProjectionInfo FindMemberProjection(Expression expression) return ProjectionInfo.ForProjection(processor(methodCallExpression)); } } - var memberExpression = expression as MemberExpression; - if (memberExpression != null) + var memberExpression = expression as MemberExpression; + if (memberExpression != null) { - var signature = Signature(memberExpression.Member); + var signature = Signature(memberExpression.Member); Func processor; if (_customProjectionProcessors.TryGetValue(signature, out processor)) { - return ProjectionInfo.ForProjection(processor(memberExpression)); + return ProjectionInfo.ForProjection(processor(memberExpression)); } } return ProjectionInfo.ForProperty(FindMemberExpression(expression)); } + private static Expression UnwrapConvertExpression(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + if (!IsConversion(unaryExpression.NodeType)) + { + if (IsSupportedUnaryExpression(unaryExpression)) + return null; + + throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); + } + return unaryExpression.Operand; + } + + return null; + } + + private static bool IsSupportedUnaryExpression(UnaryExpression expression) + { + return expression.NodeType == ExpressionType.Negate; + } + + private static ProjectionInfo AsArithmeticProjection(Expression expression) + { + if (!(expression is BinaryExpression be)) + { + if (expression is UnaryExpression unary && unary.NodeType == ExpressionType.Negate) + { + return ProjectionInfo.ForProjection( + new SqlFunctionProjection(_unaryNegateTemplate, TypeFactory.HeuristicType(unary.Type), FindMemberProjection(unary.Operand).AsProjection())); + } + + var unwrapExpression = UnwrapConvertExpression(expression); + return unwrapExpression != null ? AsArithmeticProjection(unwrapExpression) : null; + } + + if (!_binaryArithmethicTemplates.TryGetValue(be.NodeType, out var template)) + { + return null; + } + + return ProjectionInfo.ForProjection( + new SqlFunctionProjection( + template, + TypeFactory.HeuristicType(be.Type), + FindMemberProjection(be.Left).AsProjection(), + FindMemberProjection(be.Right).AsProjection())); + } + //http://stackoverflow.com/a/2509524/259946 private static readonly Regex GeneratedMemberNameRegex = new Regex(@"^(CS\$)?<\w*>[1-9a-s]__[a-zA-Z]+[0-9]*$", RegexOptions.Compiled | RegexOptions.Singleline); @@ -407,13 +472,10 @@ private static System.Type FindMemberType(Expression expression) return memberExpression.Type; } - var unaryExpression = expression as UnaryExpression; - if (unaryExpression != null) + var unwrapExpression = UnwrapConvertExpression(expression); + if (unwrapExpression != null) { - if (!IsConversion(unaryExpression.NodeType)) - throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); - - return FindMemberType(unaryExpression.Operand); + return FindMemberType(unwrapExpression); } var methodCallExpression = expression as MethodCallExpression; @@ -422,6 +484,9 @@ private static System.Type FindMemberType(Expression expression) return methodCallExpression.Method.ReturnType; } + if (expression is BinaryExpression || expression is UnaryExpression) + return expression.Type; + throw new ArgumentException("Could not determine member type from " + expression, nameof(expression)); } @@ -443,13 +508,10 @@ private static bool IsMemberExpression(Expression expression) return EvaluatesToNull(memberExpression.Expression); } - var unaryExpression = expression as UnaryExpression; - if (unaryExpression != null) + var unwrapExpression = UnwrapConvertExpression(expression); + if (unwrapExpression != null) { - if (!IsConversion(unaryExpression.NodeType)) - throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); - - return IsMemberExpression(unaryExpression.Operand); + return IsMemberExpression(unwrapExpression); } var methodCallExpression = expression as MethodCallExpression; @@ -504,21 +566,12 @@ private static object ConvertType(object value, System.Type type) throw new ArgumentException(string.Format("Cannot convert '{0}' to {1}", value, type)); } - private static ICriterion ProcessSimpleExpression(BinaryExpression be) - { - if (be.Left.NodeType == ExpressionType.Call && ((MethodCallExpression)be.Left).Method.Name == "CompareString") - return ProcessVisualBasicStringComparison(be); - - return ProcessSimpleExpression(be.Left, be.Right, be.NodeType); - } - - private static ICriterion ProcessSimpleExpression(Expression left, Expression right, ExpressionType nodeType) + private static ICriterion ProcessSimpleExpression(Expression left, TypedValue rightValue, ExpressionType nodeType) { ProjectionInfo property = FindMemberProjection(left); System.Type propertyType = FindMemberType(left); - object value = FindValue(right); - value = ConvertType(value, propertyType); + var value = ConvertType(rightValue.Value, propertyType); if (value == null) return ProcessSimpleNullExpression(property, nodeType); @@ -530,14 +583,17 @@ private static ICriterion ProcessSimpleExpression(Expression left, Expression ri return simpleExpressionCreator(property, value); } - private static ICriterion ProcessVisualBasicStringComparison(BinaryExpression be) + private static ICriterion ProcessAsVisualBasicStringComparison(Expression left, ExpressionType nodeType) { - var methodCall = (MethodCallExpression)be.Left; + if (left.NodeType != ExpressionType.Call) + { + return null; + } - if (IsMemberExpression(methodCall.Arguments[1])) - return ProcessMemberExpression(methodCall.Arguments[0], methodCall.Arguments[1], be.NodeType); - else - return ProcessSimpleExpression(methodCall.Arguments[0], methodCall.Arguments[1], be.NodeType); + var methodCall = (MethodCallExpression) left; + return methodCall.Method.Name == "CompareString" + ? ProcessMemberExpression(methodCall.Arguments[0], methodCall.Arguments[1], nodeType) + : null; } private static ICriterion ProcessSimpleNullExpression(ProjectionInfo property, ExpressionType expressionType) @@ -552,16 +608,16 @@ private static ICriterion ProcessSimpleNullExpression(ProjectionInfo property, E throw new ArgumentException("Cannot supply null value to operator " + expressionType, nameof(expressionType)); } - private static ICriterion ProcessMemberExpression(BinaryExpression be) - { - return ProcessMemberExpression(be.Left, be.Right, be.NodeType); - } - private static ICriterion ProcessMemberExpression(Expression left, Expression right, ExpressionType nodeType) { - ProjectionInfo leftProperty = FindMemberProjection(left); ProjectionInfo rightProperty = FindMemberProjection(right); + if (rightProperty.IsConstant(out var constProjection)) + { + return ProcessAsVisualBasicStringComparison(left, nodeType) + ?? ProcessSimpleExpression(left, constProjection.TypedValue, nodeType); + } + ProjectionInfo leftProperty = FindMemberProjection(left); Func propertyExpressionCreator; if (!_propertyExpressionCreators.TryGetValue(nodeType, out propertyExpressionCreator)) throw new InvalidOperationException("Unhandled property expression type: " + nodeType); @@ -599,11 +655,7 @@ private static ICriterion ProcessBinaryExpression(BinaryExpression expression) case ExpressionType.GreaterThanOrEqual: case ExpressionType.LessThan: case ExpressionType.LessThanOrEqual: - if (IsMemberExpression(expression.Right)) - return ProcessMemberExpression(expression); - else - return ProcessSimpleExpression(expression); - + return ProcessMemberExpression(expression.Left, expression.Right, expression.NodeType); default: throw new NotImplementedException("Unhandled binary expression: " + expression.NodeType + ", " + expression); }