diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs
index 8dfdc304aae..5581badf6c8 100644
--- a/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs
+++ b/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs
@@ -11,6 +11,7 @@
using System;
using System.Linq;
using NHibernate.Cfg;
+using NHibernate.Dialect;
using NUnit.Framework;
using NHibernate.Linq;
@@ -110,5 +111,50 @@ into temp
Assert.That(result.Count, Is.EqualTo(77));
}
+
+ [Test]
+ public async Task CheckSqlFunctionNameLongCountAsync()
+ {
+ var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
+ using (var sqlLog = new SqlLogSpy())
+ {
+ var result = await (db.Orders.LongCountAsync());
+ Assert.That(result, Is.EqualTo(830));
+
+ var log = sqlLog.GetWholeLog();
+ Assert.That(log, Does.Contain($"{name}("));
+ }
+ }
+
+ [Test]
+ public async Task CheckSqlFunctionNameForCountAsync()
+ {
+ using (var sqlLog = new SqlLogSpy())
+ {
+ var result = await (db.Orders.CountAsync());
+ Assert.That(result, Is.EqualTo(830));
+
+ var log = sqlLog.GetWholeLog();
+ Assert.That(log, Does.Contain("count("));
+ }
+ }
+
+ [Test]
+ public async Task CheckMssqlCountCastAsync()
+ {
+ if (!(Dialect is MsSql2000Dialect))
+ {
+ Assert.Ignore();
+ }
+
+ using (var sqlLog = new SqlLogSpy())
+ {
+ var result = await (db.Orders.CountAsync());
+ Assert.That(result, Is.EqualTo(830));
+
+ var log = sqlLog.GetWholeLog();
+ Assert.That(log, Does.Not.Contain("cast("));
+ }
+ }
}
}
diff --git a/src/NHibernate.Test/Linq/ByMethod/CountTests.cs b/src/NHibernate.Test/Linq/ByMethod/CountTests.cs
index 3bb93c7a083..7ef2d7dbbe6 100644
--- a/src/NHibernate.Test/Linq/ByMethod/CountTests.cs
+++ b/src/NHibernate.Test/Linq/ByMethod/CountTests.cs
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using NHibernate.Cfg;
+using NHibernate.Dialect;
using NUnit.Framework;
namespace NHibernate.Test.Linq.ByMethod
@@ -98,5 +99,50 @@ into temp
Assert.That(result.Count, Is.EqualTo(77));
}
+
+ [Test]
+ public void CheckSqlFunctionNameLongCount()
+ {
+ var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
+ using (var sqlLog = new SqlLogSpy())
+ {
+ var result = db.Orders.LongCount();
+ Assert.That(result, Is.EqualTo(830));
+
+ var log = sqlLog.GetWholeLog();
+ Assert.That(log, Does.Contain($"{name}("));
+ }
+ }
+
+ [Test]
+ public void CheckSqlFunctionNameForCount()
+ {
+ using (var sqlLog = new SqlLogSpy())
+ {
+ var result = db.Orders.Count();
+ Assert.That(result, Is.EqualTo(830));
+
+ var log = sqlLog.GetWholeLog();
+ Assert.That(log, Does.Contain("count("));
+ }
+ }
+
+ [Test]
+ public void CheckMssqlCountCast()
+ {
+ if (!(Dialect is MsSql2000Dialect))
+ {
+ Assert.Ignore();
+ }
+
+ using (var sqlLog = new SqlLogSpy())
+ {
+ var result = db.Orders.Count();
+ Assert.That(result, Is.EqualTo(830));
+
+ var log = sqlLog.GetWholeLog();
+ Assert.That(log, Does.Not.Contain("cast("));
+ }
+ }
}
}
diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs
index af1995dca06..53c30b55175 100644
--- a/src/NHibernate/Dialect/Dialect.cs
+++ b/src/NHibernate/Dialect/Dialect.cs
@@ -55,6 +55,7 @@ public abstract partial class Dialect
static Dialect()
{
StandardAggregateFunctions["count"] = new CountQueryFunctionInfo();
+ StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo();
StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo();
StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false);
StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false);
diff --git a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
index 299605adb03..e0b78f1e1e2 100644
--- a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
+++ b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
@@ -1,15 +1,15 @@
using System;
using System.Collections;
-using System.Text;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
-using NHibernate.Util;
namespace NHibernate.Dialect.Function
{
[Serializable]
- public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar
+ public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
{
private IType returnType = null;
private readonly string name;
@@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping)
return returnType ?? columnType;
}
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+ }
+
+ ///
+ public string FunctionName => name;
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/ISQLFunction.cs b/src/NHibernate/Dialect/Function/ISQLFunction.cs
index 4433e5dc63c..5302625a01e 100644
--- a/src/NHibernate/Dialect/Function/ISQLFunction.cs
+++ b/src/NHibernate/Dialect/Function/ISQLFunction.cs
@@ -1,4 +1,6 @@
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
@@ -41,4 +43,45 @@ public interface ISQLFunction
/// SQL fragment for the function.
SqlString Render(IList args, ISessionFactoryImplementor factory);
}
+
+ // 6.0 TODO: Remove
+ internal static class SQLFunctionExtensions
+ {
+ ///
+ /// Get the type that will be effectively returned by the underlying database.
+ ///
+ /// The sql function.
+ /// The types of arguments.
+ /// The mapping for retrieving the argument sql types.
+ /// Whether to throw when the number of arguments is invalid or they are not supported.
+ /// The type returned by the underlying database or when the number of arguments
+ /// is invalid or they are not supported.
+ /// When is set to and the
+ /// number of arguments is invalid or they are not supported.
+ public static IType GetEffectiveReturnType(
+ this ISQLFunction sqlFunction,
+ IEnumerable argumentTypes,
+ IMapping mapping,
+ bool throwOnError)
+ {
+ if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
+ {
+ try
+ {
+ return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
+ }
+ catch (QueryException)
+ {
+ if (throwOnError)
+ {
+ throw;
+ }
+
+ return null;
+ }
+ }
+
+ return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError);
+ }
+ }
}
diff --git a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
new file mode 100644
index 00000000000..e2db4747198
--- /dev/null
+++ b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
@@ -0,0 +1,27 @@
+using System.Collections.Generic;
+using NHibernate.Engine;
+using NHibernate.Type;
+
+namespace NHibernate.Dialect.Function
+{
+ // 6.0 TODO: Merge into ISQLFunction
+ internal interface ISQLFunctionExtended : ISQLFunction
+ {
+ ///
+ /// The function name or when multiple functions/operators/statements are used.
+ ///
+ string FunctionName { get; }
+
+ ///
+ /// Get the type that will be effectively returned by the underlying database.
+ ///
+ /// The types of arguments.
+ /// The mapping for retrieving the argument sql types.
+ /// Whether to throw when the number of arguments is invalid or they are not supported.
+ /// The type returned by the underlying database or when the number of arguments
+ /// is invalid or they are not supported.
+ /// When is set to and the
+ /// number of arguments is invalid or they are not supported.
+ IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError);
+ }
+}
diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs
index 7acb3cb9b4f..35c760e85b3 100644
--- a/src/NHibernate/Dialect/MsSql2000Dialect.cs
+++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs
@@ -286,7 +286,8 @@ protected virtual void RegisterKeywords()
protected virtual void RegisterFunctions()
{
- RegisterFunction("count", new CountBigQueryFunction());
+ RegisterFunction("count", new CountQueryFunction());
+ RegisterFunction("count_big", new CountBigQueryFunction());
RegisterFunction("abs", new StandardSQLFunction("abs"));
RegisterFunction("absval", new StandardSQLFunction("absval"));
@@ -704,11 +705,15 @@ protected virtual string GetSelectExistingObject(string catalog, string schema,
[Serializable]
protected class CountBigQueryFunction : ClassicAggregateFunction
{
- public CountBigQueryFunction() : base("count_big", true) { }
+ public CountBigQueryFunction() : base("count_big", true, NHibernateUtil.Int64) { }
+ }
- public override IType ReturnType(IType columnType, IMapping mapping)
+ [Serializable]
+ private class CountQueryFunction : CountQueryFunctionInfo
+ {
+ public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
{
- return NHibernateUtil.Int64;
+ return NHibernateUtil.Int32;
}
}
diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs
index 4f6f902891d..0f1c7a64603 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs
@@ -310,6 +310,12 @@ private void EndFunctionTemplate(IASTNode m)
}
}
+ private void OutAggregateFunctionName(IASTNode m)
+ {
+ var aggregateNode = (AggregateNode) m;
+ Out(aggregateNode.FunctionName);
+ }
+
private void CommaBetweenParameters(String comma)
{
writer.CommaBetweenParameters(comma);
diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g
index 73d04c792ec..21cb3c7521e 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g
+++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g
@@ -150,7 +150,7 @@ selectExpr
;
count
- : ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
+ : ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
;
distinctOrAll
@@ -344,7 +344,7 @@ caseExpr
;
aggregate
- : ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } )
+ : ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } )
;
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
index d7dc2f9a0f0..a3887a56c01 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
@@ -1,5 +1,6 @@
using System;
using Antlr.Runtime;
+using NHibernate.Dialect.Function;
using NHibernate.Type;
using NHibernate.Hql.Ast.ANTLR.Util;
@@ -19,6 +20,19 @@ public AggregateNode(IToken token)
{
}
+ public string FunctionName
+ {
+ get
+ {
+ if (SessionFactoryHelper.FindSQLFunction(Text) is ISQLFunctionExtended sqlFunction)
+ {
+ return sqlFunction.FunctionName;
+ }
+
+ return Text;
+ }
+ }
+
public override IType DataType
{
get
@@ -31,6 +45,7 @@ public override IType DataType
base.DataType = value;
}
}
+
public override void SetScalarColumnText(int i)
{
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
index 88ec9cc22fd..ef6d6272043 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
@@ -1,4 +1,5 @@
using Antlr.Runtime;
+using NHibernate.Dialect.Function;
using NHibernate.Hql.Ast.ANTLR.Util;
using NHibernate.Type;
@@ -9,7 +10,7 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
/// Author: josh
/// Ported by: Steve Strong
///
- class CountNode : AbstractSelectExpression, ISelectExpression
+ class CountNode : AggregateNode, ISelectExpression
{
public CountNode(IToken token) : base(token)
{
@@ -26,9 +27,5 @@ public override IType DataType
base.DataType = value;
}
}
- public override void SetScalarColumnText(int i)
- {
- ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
- }
}
}
diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
index bb208295afc..dd59f2ff0f0 100755
--- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
+++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
@@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child)
return new HqlCount(_factory, child);
}
+ public HqlCountBig CountBig(HqlExpression child)
+ {
+ return new HqlCountBig(_factory, child);
+ }
+
public HqlRowStar RowStar()
{
return new HqlRowStar(_factory);
diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs
index 160739d7f16..8967174bde3 100755
--- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs
+++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs
@@ -697,6 +697,19 @@ public HqlCount(IASTFactory factory, HqlExpression child)
}
}
+ public class HqlCountBig : HqlExpression
+ {
+ public HqlCountBig(IASTFactory factory)
+ : base(HqlSqlWalker.COUNT, "count_big", factory)
+ {
+ }
+
+ public HqlCountBig(IASTFactory factory, HqlExpression child)
+ : base(HqlSqlWalker.COUNT, "count_big", factory, child)
+ {
+ }
+ }
+
public class HqlAs : HqlExpression
{
public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type)
diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs
index 224af655a24..cd9cd49eadb 100644
--- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs
+++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs
@@ -4,6 +4,7 @@
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
+using NHibernate.Dialect.Function;
using NHibernate.Engine.Query;
using NHibernate.Hql.Ast;
using NHibernate.Hql.Ast.ANTLR;
@@ -257,7 +258,22 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression)
protected HqlTreeNode VisitNhCount(NhCountExpression expression)
{
- return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type);
+ string functionName;
+ HqlExpression countHqlExpression;
+ if (expression is NhLongCountExpression)
+ {
+ functionName = "count_big";
+ countHqlExpression = _hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression());
+ }
+ else
+ {
+ functionName = "count";
+ countHqlExpression = _hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression());
+ }
+
+ return IsCastRequired(functionName, expression.Expression, expression.Type)
+ ? (HqlTreeNode) _hqlTreeBuilder.Cast(countHqlExpression, expression.Type)
+ : _hqlTreeBuilder.TransparentCast(countHqlExpression, expression.Type);
}
protected HqlTreeNode VisitNhMin(NhMinExpression expression)
@@ -595,7 +611,7 @@ private bool IsCastRequired(Expression expression, System.Type toType, out bool
{
existType = false;
return toType != typeof(object) &&
- IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
+ IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType);
}
private bool IsCastRequired(IType type, IType toType, out bool existType)
@@ -639,7 +655,7 @@ private bool IsCastRequired(IType type, IType toType, out bool existType)
private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType)
{
- var argumentType = GetType(argumentExpression);
+ var argumentType = ExpressionsHelper.GetType(_parameters, argumentExpression);
if (argumentType == null || returnType == typeof(object))
{
return false;
@@ -657,18 +673,8 @@ private bool IsCastRequired(string sqlFunctionName, Expression argumentExpressio
return true; // Fallback to the old behavior
}
- var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory);
+ var fnReturnType = sqlFunction.GetEffectiveReturnType(new[] {argumentType}, _parameters.SessionFactory, false);
return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _);
}
-
- private IType GetType(Expression expression)
- {
- // Try to get the mapped type for the member as it may be a non default one
- return expression.Type == typeof(object)
- ? null
- : (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _)
- ? type
- : TypeFactory.GetDefaultTypeFor(expression.Type));
- }
}
}
diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs
index 86fb8d04bd3..376a42ccda0 100644
--- a/src/NHibernate/Util/ExpressionsHelper.cs
+++ b/src/NHibernate/Util/ExpressionsHelper.cs
@@ -28,6 +28,29 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi
return ((MemberExpression)expression.Body).Member;
}
+ ///
+ /// Get the mapped type for the given expression.
+ ///
+ /// The query parameters.
+ /// The expression.
+ /// The mapped type of the expression or when the mapped type was not
+ /// found and the type is .
+ internal static IType GetType(VisitorParameters parameters, Expression expression)
+ {
+ if (expression is ConstantExpression constantExpression &&
+ parameters.ConstantToParameterMap.TryGetValue(constantExpression, out var param))
+ {
+ return param.Type;
+ }
+
+ if (TryGetMappedType(parameters.SessionFactory, expression, out var type, out _, out _, out _))
+ {
+ return type;
+ }
+
+ return expression.Type == typeof(object) ? null : TypeFactory.HeuristicType(expression.Type);
+ }
+
///
/// Try to get the mapped nullability from the given expression.
///