Skip to content

Reduce cast usage for COUNT aggregate and add support for Mssql count_big #2061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System;
using System.Linq;
using NHibernate.Cfg;
using NHibernate.Dialect;
using NUnit.Framework;
using NHibernate.Linq;

Expand Down Expand Up @@ -110,5 +111,50 @@ into temp

Assert.That(result.Count, Is.EqualTo(77));
}

[Test]
public async Task CheckSqlFunctionNameLongCountAsync()
{
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
using (var sqlLog = new SqlLogSpy())
{
var result = await (db.Orders.LongCountAsync());
Assert.That(result, Is.EqualTo(830));

var log = sqlLog.GetWholeLog();
Assert.That(log, Does.Contain($"{name}("));
}
}

[Test]
public async Task CheckSqlFunctionNameForCountAsync()
{
using (var sqlLog = new SqlLogSpy())
{
var result = await (db.Orders.CountAsync());
Assert.That(result, Is.EqualTo(830));

var log = sqlLog.GetWholeLog();
Assert.That(log, Does.Contain("count("));
}
}

[Test]
public async Task CheckMssqlCountCastAsync()
{
if (!(Dialect is MsSql2000Dialect))
{
Assert.Ignore();
}

using (var sqlLog = new SqlLogSpy())
{
var result = await (db.Orders.CountAsync());
Assert.That(result, Is.EqualTo(830));

var log = sqlLog.GetWholeLog();
Assert.That(log, Does.Not.Contain("cast("));
}
}
}
}
46 changes: 46 additions & 0 deletions src/NHibernate.Test/Linq/ByMethod/CountTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using NHibernate.Cfg;
using NHibernate.Dialect;
using NUnit.Framework;

namespace NHibernate.Test.Linq.ByMethod
Expand Down Expand Up @@ -98,5 +99,50 @@ into temp

Assert.That(result.Count, Is.EqualTo(77));
}

[Test]
public void CheckSqlFunctionNameLongCount()
{
var name = Dialect is MsSql2000Dialect ? "count_big" : "count";
using (var sqlLog = new SqlLogSpy())
{
var result = db.Orders.LongCount();
Assert.That(result, Is.EqualTo(830));

var log = sqlLog.GetWholeLog();
Assert.That(log, Does.Contain($"{name}("));
}
}

[Test]
public void CheckSqlFunctionNameForCount()
{
using (var sqlLog = new SqlLogSpy())
{
var result = db.Orders.Count();
Assert.That(result, Is.EqualTo(830));

var log = sqlLog.GetWholeLog();
Assert.That(log, Does.Contain("count("));
}
}

[Test]
public void CheckMssqlCountCast()
{
if (!(Dialect is MsSql2000Dialect))
{
Assert.Ignore();
}

using (var sqlLog = new SqlLogSpy())
{
var result = db.Orders.Count();
Assert.That(result, Is.EqualTo(830));

var log = sqlLog.GetWholeLog();
Assert.That(log, Does.Not.Contain("cast("));
}
}
}
}
1 change: 1 addition & 0 deletions src/NHibernate/Dialect/Dialect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public abstract partial class Dialect
static Dialect()
{
StandardAggregateFunctions["count"] = new CountQueryFunctionInfo();
StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo();
StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo();
StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false);
StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false);
Expand Down
15 changes: 12 additions & 3 deletions src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
using System;
using System.Collections;
using System.Text;
using System.Collections.Generic;
using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
using NHibernate.Util;

namespace NHibernate.Dialect.Function
{
[Serializable]
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar
public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
{
private IType returnType = null;
private readonly string name;
Expand Down Expand Up @@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping)
return returnType ?? columnType;
}

/// <inheritdoc />
public virtual IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
{
return ReturnType(argumentTypes.FirstOrDefault(), mapping);
}

/// <inheritdoc />
public string FunctionName => name;

public bool HasArguments
{
get { return true; }
Expand Down
43 changes: 43 additions & 0 deletions src/NHibernate/Dialect/Function/ISQLFunction.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
Expand Down Expand Up @@ -41,4 +43,45 @@ public interface ISQLFunction
/// <returns>SQL fragment for the function.</returns>
SqlString Render(IList args, ISessionFactoryImplementor factory);
}

// 6.0 TODO: Remove
internal static class SQLFunctionExtensions
{
/// <summary>
/// Get the type that will be effectively returned by the underlying database.
/// </summary>
/// <param name="sqlFunction">The sql function.</param>
/// <param name="argumentTypes">The types of arguments.</param>
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
/// is invalid or they are not supported.</returns>
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
/// number of arguments is invalid or they are not supported.</exception>
public static IType GetEffectiveReturnType(
this ISQLFunction sqlFunction,
IEnumerable<IType> argumentTypes,
IMapping mapping,
bool throwOnError)
{
if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
{
try
{
return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
}
catch (QueryException)
{
if (throwOnError)
{
throw;
}

return null;
}
}

return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError);
}
}
}
27 changes: 27 additions & 0 deletions src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System.Collections.Generic;
using NHibernate.Engine;
using NHibernate.Type;

namespace NHibernate.Dialect.Function
{
// 6.0 TODO: Merge into ISQLFunction
internal interface ISQLFunctionExtended : ISQLFunction
{
/// <summary>
/// The function name or <see langword="null"/> when multiple functions/operators/statements are used.
/// </summary>
string FunctionName { get; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be renamed to Name to reduce need of obsoleting things.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not, renamed in #2359.


/// <summary>
/// Get the type that will be effectively returned by the underlying database.
/// </summary>
/// <param name="argumentTypes">The types of arguments.</param>
/// <param name="mapping">The mapping for retrieving the argument sql types.</param>
/// <param name="throwOnError">Whether to throw when the number of arguments is invalid or they are not supported.</param>
/// <returns>The type returned by the underlying database or <see langword="null"/> when the number of arguments
/// is invalid or they are not supported.</returns>
/// <exception cref="QueryException">When <paramref name="throwOnError"/> is set to <see langword="true"/> and the
/// number of arguments is invalid or they are not supported.</exception>
IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError);
}
}
13 changes: 9 additions & 4 deletions src/NHibernate/Dialect/MsSql2000Dialect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ protected virtual void RegisterKeywords()

protected virtual void RegisterFunctions()
{
RegisterFunction("count", new CountBigQueryFunction());
RegisterFunction("count", new CountQueryFunction());
RegisterFunction("count_big", new CountBigQueryFunction());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this still causing the breaking change you were writing about?

For example the following test would fail in Sql Server:

object count = s.CreateQuery("select count(*) from Simple").UniqueResult();
Assert.IsTrue(count is Int64);

because count effective return type in Sql Server is an int.

This changes the registration of HQL count for SQL-Server to map to SQL-Server count instead of count_big, which returns an int32 and fails if the count is bigger.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this still causing the breaking change you were writing about?

For HQL nothing is changed, the new GetEffectiveReturnType method is used only for Linq queries. CountQueryFunction.ReturnType still returns Int64, which is what HQL uses to determine the type, so hql queries using count will always be casted to Int64.


RegisterFunction("abs", new StandardSQLFunction("abs"));
RegisterFunction("absval", new StandardSQLFunction("absval"));
Expand Down Expand Up @@ -704,11 +705,15 @@ protected virtual string GetSelectExistingObject(string catalog, string schema,
[Serializable]
protected class CountBigQueryFunction : ClassicAggregateFunction
{
public CountBigQueryFunction() : base("count_big", true) { }
public CountBigQueryFunction() : base("count_big", true, NHibernateUtil.Int64) { }
}

public override IType ReturnType(IType columnType, IMapping mapping)
[Serializable]
private class CountQueryFunction : CountQueryFunctionInfo
{
public override IType GetEffectiveReturnType(IEnumerable<IType> argumentTypes, IMapping mapping, bool throwOnError)
{
return NHibernateUtil.Int64;
return NHibernateUtil.Int32;
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@ private void EndFunctionTemplate(IASTNode m)
}
}

private void OutAggregateFunctionName(IASTNode m)
{
var aggregateNode = (AggregateNode) m;
Out(aggregateNode.FunctionName);
}

private void CommaBetweenParameters(String comma)
{
writer.CommaBetweenParameters(comma);
Expand Down
4 changes: 2 additions & 2 deletions src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ selectExpr
;

count
: ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
: ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } )
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this invalidate my previous comment? I mean, was this causing the SQL-Server HQL count to count_big mapping to be ignored by this piece of code, causing NHibernate to anyway use SQL-Server int32 count?

Copy link
Contributor Author

@maca88 maca88 Mar 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, was this causing the SQL-Server HQL count to count_big mapping to be ignored by this piece of code, causing NHibernate to anyway use SQL-Server int32 count?

Correct, count was hardcoded, which was the reason why changing from count to count_big in CountBigQueryFunction didn't worked.

;

distinctOrAll
Expand Down Expand Up @@ -344,7 +344,7 @@ caseExpr
;

aggregate
: ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } )
: ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } )
;


Expand Down
15 changes: 15 additions & 0 deletions src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using Antlr.Runtime;
using NHibernate.Dialect.Function;
using NHibernate.Type;
using NHibernate.Hql.Ast.ANTLR.Util;

Expand All @@ -19,6 +20,19 @@ public AggregateNode(IToken token)
{
}

public string FunctionName
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be an overgeneralization and should be refactored into a node-specific code unless you see that it can be extended for functions other than count and count_big

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be extended for string concatenation, which is supported by some databases (Sql server, MySql, Postgresql and Oracle), where different function names are used (LISTAGG, GROUP_CONCAT, STRING_AGG).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maca88 can you provide some sort of prototype of string_agg? I'm dying to use this with linq provider

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To my knowledge it is not possible to implement it without modifying NHibernate source code. It is on my todo list for the next release, so unfortunately you will have to wait.

{
get
{
if (SessionFactoryHelper.FindSQLFunction(Text) is ISQLFunctionExtended sqlFunction)
{
return sqlFunction.FunctionName;
}

return Text;
}
}

public override IType DataType
{
get
Expand All @@ -31,6 +45,7 @@ public override IType DataType
base.DataType = value;
}
}

public override void SetScalarColumnText(int i)
{
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
Expand Down
7 changes: 2 additions & 5 deletions src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Antlr.Runtime;
using NHibernate.Dialect.Function;
using NHibernate.Hql.Ast.ANTLR.Util;
using NHibernate.Type;

Expand All @@ -9,7 +10,7 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
/// Author: josh
/// Ported by: Steve Strong
/// </summary>
class CountNode : AbstractSelectExpression, ISelectExpression
class CountNode : AggregateNode, ISelectExpression
{
public CountNode(IToken token) : base(token)
{
Expand All @@ -26,9 +27,5 @@ public override IType DataType
base.DataType = value;
}
}
public override void SetScalarColumnText(int i)
{
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i);
}
}
}
5 changes: 5 additions & 0 deletions src/NHibernate/Hql/Ast/HqlTreeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child)
return new HqlCount(_factory, child);
}

public HqlCountBig CountBig(HqlExpression child)
{
return new HqlCountBig(_factory, child);
}

public HqlRowStar RowStar()
{
return new HqlRowStar(_factory);
Expand Down
13 changes: 13 additions & 0 deletions src/NHibernate/Hql/Ast/HqlTreeNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,19 @@ public HqlCount(IASTFactory factory, HqlExpression child)
}
}

public class HqlCountBig : HqlExpression
{
public HqlCountBig(IASTFactory factory)
: base(HqlSqlWalker.COUNT, "count_big", factory)
{
}

public HqlCountBig(IASTFactory factory, HqlExpression child)
: base(HqlSqlWalker.COUNT, "count_big", factory, child)
{
}
}

public class HqlAs : HqlExpression
{
public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type)
Expand Down
Loading