diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs index 8dfdc304aae..5581badf6c8 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs @@ -11,6 +11,7 @@ using System; using System.Linq; using NHibernate.Cfg; +using NHibernate.Dialect; using NUnit.Framework; using NHibernate.Linq; @@ -110,5 +111,50 @@ into temp Assert.That(result.Count, Is.EqualTo(77)); } + + [Test] + public async Task CheckSqlFunctionNameLongCountAsync() + { + var name = Dialect is MsSql2000Dialect ? "count_big" : "count"; + using (var sqlLog = new SqlLogSpy()) + { + var result = await (db.Orders.LongCountAsync()); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain($"{name}(")); + } + } + + [Test] + public async Task CheckSqlFunctionNameForCountAsync() + { + using (var sqlLog = new SqlLogSpy()) + { + var result = await (db.Orders.CountAsync()); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain("count(")); + } + } + + [Test] + public async Task CheckMssqlCountCastAsync() + { + if (!(Dialect is MsSql2000Dialect)) + { + Assert.Ignore(); + } + + using (var sqlLog = new SqlLogSpy()) + { + var result = await (db.Orders.CountAsync()); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Not.Contain("cast(")); + } + } } } diff --git a/src/NHibernate.Test/Linq/ByMethod/CountTests.cs b/src/NHibernate.Test/Linq/ByMethod/CountTests.cs index 3bb93c7a083..7ef2d7dbbe6 100644 --- a/src/NHibernate.Test/Linq/ByMethod/CountTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/CountTests.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using NHibernate.Cfg; +using NHibernate.Dialect; using NUnit.Framework; namespace NHibernate.Test.Linq.ByMethod @@ -98,5 +99,50 @@ into temp Assert.That(result.Count, Is.EqualTo(77)); } + + [Test] + public void CheckSqlFunctionNameLongCount() + { + var name = Dialect is MsSql2000Dialect ? "count_big" : "count"; + using (var sqlLog = new SqlLogSpy()) + { + var result = db.Orders.LongCount(); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain($"{name}(")); + } + } + + [Test] + public void CheckSqlFunctionNameForCount() + { + using (var sqlLog = new SqlLogSpy()) + { + var result = db.Orders.Count(); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain("count(")); + } + } + + [Test] + public void CheckMssqlCountCast() + { + if (!(Dialect is MsSql2000Dialect)) + { + Assert.Ignore(); + } + + using (var sqlLog = new SqlLogSpy()) + { + var result = db.Orders.Count(); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Not.Contain("cast(")); + } + } } } diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index af1995dca06..53c30b55175 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -55,6 +55,7 @@ public abstract partial class Dialect static Dialect() { StandardAggregateFunctions["count"] = new CountQueryFunctionInfo(); + StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo(); StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo(); StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false); StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false); diff --git a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs index 299605adb03..e0b78f1e1e2 100644 --- a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs +++ b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs @@ -1,15 +1,15 @@ using System; using System.Collections; -using System.Text; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; -using NHibernate.Util; namespace NHibernate.Dialect.Function { [Serializable] - public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar + public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended { private IType returnType = null; private readonly string name; @@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping) return returnType ?? columnType; } + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return ReturnType(argumentTypes.FirstOrDefault(), mapping); + } + + /// + public string FunctionName => name; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/ISQLFunction.cs b/src/NHibernate/Dialect/Function/ISQLFunction.cs index 4433e5dc63c..5302625a01e 100644 --- a/src/NHibernate/Dialect/Function/ISQLFunction.cs +++ b/src/NHibernate/Dialect/Function/ISQLFunction.cs @@ -1,4 +1,6 @@ using System.Collections; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; @@ -41,4 +43,45 @@ public interface ISQLFunction /// SQL fragment for the function. SqlString Render(IList args, ISessionFactoryImplementor factory); } + + // 6.0 TODO: Remove + internal static class SQLFunctionExtensions + { + /// + /// Get the type that will be effectively returned by the underlying database. + /// + /// The sql function. + /// The types of arguments. + /// The mapping for retrieving the argument sql types. + /// Whether to throw when the number of arguments is invalid or they are not supported. + /// The type returned by the underlying database or when the number of arguments + /// is invalid or they are not supported. + /// When is set to and the + /// number of arguments is invalid or they are not supported. + public static IType GetEffectiveReturnType( + this ISQLFunction sqlFunction, + IEnumerable argumentTypes, + IMapping mapping, + bool throwOnError) + { + if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction)) + { + try + { + return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping); + } + catch (QueryException) + { + if (throwOnError) + { + throw; + } + + return null; + } + } + + return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError); + } + } } diff --git a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs new file mode 100644 index 00000000000..e2db4747198 --- /dev/null +++ b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs @@ -0,0 +1,27 @@ +using System.Collections.Generic; +using NHibernate.Engine; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + // 6.0 TODO: Merge into ISQLFunction + internal interface ISQLFunctionExtended : ISQLFunction + { + /// + /// The function name or when multiple functions/operators/statements are used. + /// + string FunctionName { get; } + + /// + /// Get the type that will be effectively returned by the underlying database. + /// + /// The types of arguments. + /// The mapping for retrieving the argument sql types. + /// Whether to throw when the number of arguments is invalid or they are not supported. + /// The type returned by the underlying database or when the number of arguments + /// is invalid or they are not supported. + /// When is set to and the + /// number of arguments is invalid or they are not supported. + IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError); + } +} diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index 7acb3cb9b4f..35c760e85b3 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -286,7 +286,8 @@ protected virtual void RegisterKeywords() protected virtual void RegisterFunctions() { - RegisterFunction("count", new CountBigQueryFunction()); + RegisterFunction("count", new CountQueryFunction()); + RegisterFunction("count_big", new CountBigQueryFunction()); RegisterFunction("abs", new StandardSQLFunction("abs")); RegisterFunction("absval", new StandardSQLFunction("absval")); @@ -704,11 +705,15 @@ protected virtual string GetSelectExistingObject(string catalog, string schema, [Serializable] protected class CountBigQueryFunction : ClassicAggregateFunction { - public CountBigQueryFunction() : base("count_big", true) { } + public CountBigQueryFunction() : base("count_big", true, NHibernateUtil.Int64) { } + } - public override IType ReturnType(IType columnType, IMapping mapping) + [Serializable] + private class CountQueryFunction : CountQueryFunctionInfo + { + public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) { - return NHibernateUtil.Int64; + return NHibernateUtil.Int32; } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs index 4f6f902891d..0f1c7a64603 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs @@ -310,6 +310,12 @@ private void EndFunctionTemplate(IASTNode m) } } + private void OutAggregateFunctionName(IASTNode m) + { + var aggregateNode = (AggregateNode) m; + Out(aggregateNode.FunctionName); + } + private void CommaBetweenParameters(String comma) { writer.CommaBetweenParameters(comma); diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g index 73d04c792ec..21cb3c7521e 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g +++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g @@ -150,7 +150,7 @@ selectExpr ; count - : ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } ) + : ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } ) ; distinctOrAll @@ -344,7 +344,7 @@ caseExpr ; aggregate - : ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } ) + : ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } ) ; diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs index d7dc2f9a0f0..a3887a56c01 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs @@ -1,5 +1,6 @@ using System; using Antlr.Runtime; +using NHibernate.Dialect.Function; using NHibernate.Type; using NHibernate.Hql.Ast.ANTLR.Util; @@ -19,6 +20,19 @@ public AggregateNode(IToken token) { } + public string FunctionName + { + get + { + if (SessionFactoryHelper.FindSQLFunction(Text) is ISQLFunctionExtended sqlFunction) + { + return sqlFunction.FunctionName; + } + + return Text; + } + } + public override IType DataType { get @@ -31,6 +45,7 @@ public override IType DataType base.DataType = value; } } + public override void SetScalarColumnText(int i) { ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs index 88ec9cc22fd..ef6d6272043 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs @@ -1,4 +1,5 @@ using Antlr.Runtime; +using NHibernate.Dialect.Function; using NHibernate.Hql.Ast.ANTLR.Util; using NHibernate.Type; @@ -9,7 +10,7 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree /// Author: josh /// Ported by: Steve Strong /// - class CountNode : AbstractSelectExpression, ISelectExpression + class CountNode : AggregateNode, ISelectExpression { public CountNode(IToken token) : base(token) { @@ -26,9 +27,5 @@ public override IType DataType base.DataType = value; } } - public override void SetScalarColumnText(int i) - { - ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i); - } } } diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index bb208295afc..dd59f2ff0f0 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child) return new HqlCount(_factory, child); } + public HqlCountBig CountBig(HqlExpression child) + { + return new HqlCountBig(_factory, child); + } + public HqlRowStar RowStar() { return new HqlRowStar(_factory); diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 160739d7f16..8967174bde3 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -697,6 +697,19 @@ public HqlCount(IASTFactory factory, HqlExpression child) } } + public class HqlCountBig : HqlExpression + { + public HqlCountBig(IASTFactory factory) + : base(HqlSqlWalker.COUNT, "count_big", factory) + { + } + + public HqlCountBig(IASTFactory factory, HqlExpression child) + : base(HqlSqlWalker.COUNT, "count_big", factory, child) + { + } + } + public class HqlAs : HqlExpression { public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 224af655a24..cd9cd49eadb 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; +using NHibernate.Dialect.Function; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Hql.Ast.ANTLR; @@ -257,7 +258,22 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) protected HqlTreeNode VisitNhCount(NhCountExpression expression) { - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type); + string functionName; + HqlExpression countHqlExpression; + if (expression is NhLongCountExpression) + { + functionName = "count_big"; + countHqlExpression = _hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()); + } + else + { + functionName = "count"; + countHqlExpression = _hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()); + } + + return IsCastRequired(functionName, expression.Expression, expression.Type) + ? (HqlTreeNode) _hqlTreeBuilder.Cast(countHqlExpression, expression.Type) + : _hqlTreeBuilder.TransparentCast(countHqlExpression, expression.Type); } protected HqlTreeNode VisitNhMin(NhMinExpression expression) @@ -595,7 +611,7 @@ private bool IsCastRequired(Expression expression, System.Type toType, out bool { existType = false; return toType != typeof(object) && - IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType); + IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType); } private bool IsCastRequired(IType type, IType toType, out bool existType) @@ -639,7 +655,7 @@ private bool IsCastRequired(IType type, IType toType, out bool existType) private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType) { - var argumentType = GetType(argumentExpression); + var argumentType = ExpressionsHelper.GetType(_parameters, argumentExpression); if (argumentType == null || returnType == typeof(object)) { return false; @@ -657,18 +673,8 @@ private bool IsCastRequired(string sqlFunctionName, Expression argumentExpressio return true; // Fallback to the old behavior } - var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory); + var fnReturnType = sqlFunction.GetEffectiveReturnType(new[] {argumentType}, _parameters.SessionFactory, false); return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _); } - - private IType GetType(Expression expression) - { - // Try to get the mapped type for the member as it may be a non default one - return expression.Type == typeof(object) - ? null - : (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _) - ? type - : TypeFactory.GetDefaultTypeFor(expression.Type)); - } } } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 86fb8d04bd3..376a42ccda0 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -28,6 +28,29 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } + /// + /// Get the mapped type for the given expression. + /// + /// The query parameters. + /// The expression. + /// The mapped type of the expression or when the mapped type was not + /// found and the type is . + internal static IType GetType(VisitorParameters parameters, Expression expression) + { + if (expression is ConstantExpression constantExpression && + parameters.ConstantToParameterMap.TryGetValue(constantExpression, out var param)) + { + return param.Type; + } + + if (TryGetMappedType(parameters.SessionFactory, expression, out var type, out _, out _, out _)) + { + return type; + } + + return expression.Type == typeof(object) ? null : TypeFactory.HeuristicType(expression.Type); + } + /// /// Try to get the mapped nullability from the given expression. ///