diff --git a/src/NHibernate/Criterion/SqlFunctionProjection.cs b/src/NHibernate/Criterion/SqlFunctionProjection.cs index 95d5298fa77..0c3bba98f98 100644 --- a/src/NHibernate/Criterion/SqlFunctionProjection.cs +++ b/src/NHibernate/Criterion/SqlFunctionProjection.cs @@ -125,7 +125,7 @@ private IType GetReturnType(ICriteria criteria, ICriteriaQuery criteriaQuery) var resultType = returnType ?? returnTypeProjection?.GetTypes(criteria, criteriaQuery).FirstOrDefault(); - return sqlFunction.ReturnType(resultType, criteriaQuery.Factory); + return sqlFunction.GetReturnType(new[] {resultType}, criteriaQuery.Factory, true); } public override TypedValue[] GetTypedValues(ICriteria criteria, ICriteriaQuery criteriaQuery) diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs index bd24dda28d7..4ec34033f9e 100644 --- a/src/NHibernate/Dialect/DB2Dialect.cs +++ b/src/NHibernate/Dialect/DB2Dialect.cs @@ -131,7 +131,7 @@ public DB2Dialect() RegisterFunction("length", new StandardSQLFunction("length", NHibernateUtil.Int32)); RegisterFunction("ltrim", new StandardSQLFunction("ltrim")); - RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32)); + RegisterFunction("mod", new ModulusFunction(true, false)); RegisterFunction("substring", new StandardSQLFunction("substr", NHibernateUtil.String)); diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 53c30b55175..6153309c7ac 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -94,7 +94,7 @@ protected Dialect() RegisterFunction("coalesce", new StandardSQLFunction("coalesce")); RegisterFunction("nullif", new StandardSQLFunction("nullif")); RegisterFunction("abs", new StandardSQLFunction("abs")); - RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32)); + RegisterFunction("mod", new ModulusFunction(false, false)); RegisterFunction("sqrt", new StandardSQLFunction("sqrt", NHibernateUtil.Double)); RegisterFunction("upper", new StandardSQLFunction("upper")); RegisterFunction("lower", new StandardSQLFunction("lower")); diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs index ba37c00cfaa..0f26f03d15a 100644 --- a/src/NHibernate/Dialect/FirebirdDialect.cs +++ b/src/NHibernate/Dialect/FirebirdDialect.cs @@ -1,5 +1,6 @@ using System; using System.Collections; +using System.Collections.Generic; using System.Data; using System.Data.Common; using NHibernate.Dialect.Function; @@ -147,7 +148,7 @@ public CastedFunction(string name, IType returnType) : base(name, returnType, fa public override SqlString Render(IList args, ISessionFactoryImplementor factory) { - return new SqlString("cast('", Name, "' as ", FunctionReturnType.SqlTypes(factory)[0].ToString(), ")"); + return new SqlString("cast('", FunctionName, "' as ", FunctionReturnType.SqlTypes(factory)[0].ToString(), ")"); } } @@ -160,7 +161,7 @@ public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.LocalDateTi public override SqlString Render(IList args, ISessionFactoryImplementor factory) { - return new SqlString(Name); + return new SqlString(FunctionName); } } @@ -205,7 +206,7 @@ public override string SelectGUIDString } [Serializable] - private class PositionFunction : ISQLFunction + private class PositionFunction : ISQLFunction, ISQLFunctionExtended { // The cast is needed, at least in the case that ?3 is a named integer parameter, otherwise firebird will generate an error. // We have a unit test to cover this potential firebird bug. @@ -214,11 +215,28 @@ private class PositionFunction : ISQLFunction private static readonly ISQLFunction LocateWith3Params = new SQLFunctionTemplate(NHibernateUtil.Int32, "position(?1, ?2, cast(?3 as int))"); + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int32; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return NHibernateUtil.Int32; + } + + /// + public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "position"; + public bool HasArguments { get { return true; } @@ -418,7 +436,8 @@ private void OverrideStandardHQLFunctions() RegisterFunction("nullif", new StandardSafeSQLFunction("nullif", 2)); RegisterFunction("lower", new StandardSafeSQLFunction("lower", NHibernateUtil.String, 1)); RegisterFunction("upper", new StandardSafeSQLFunction("upper", NHibernateUtil.String, 1)); - RegisterFunction("mod", new StandardSafeSQLFunction("mod", NHibernateUtil.Double, 2)); + // Modulo does not throw for decimal parameters but they are casted to int by Firebird, which produces unexpected results + RegisterFunction("mod", new ModulusFunction(false, false)); RegisterFunction("str", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as VARCHAR(255))")); RegisterFunction("strguid", new StandardSQLFunction("uuid_to_char", NHibernateUtil.String)); RegisterFunction("sysdate", new CastedFunction("today", NHibernateUtil.Date)); @@ -437,7 +456,7 @@ private void RegisterFirebirdServerEmbeddedFunctions() RegisterFunction("yesterday", new CastedFunction("yesterday", NHibernateUtil.Date)); RegisterFunction("tomorrow", new CastedFunction("tomorrow", NHibernateUtil.Date)); RegisterFunction("now", new CastedFunction("now", NHibernateUtil.DateTime)); - RegisterFunction("iif", new StandardSafeSQLFunction("iif", 3)); + RegisterFunction("iif", new IifSafeSQLFunction()); // New embedded functions in FB 2.0 (http://www.firebirdsql.org/rlsnotes20/rnfbtwo-str.html#str-string-func) RegisterFunction("char_length", new StandardSafeSQLFunction("char_length", NHibernateUtil.Int64, 1)); RegisterFunction("bit_length", new StandardSafeSQLFunction("bit_length", NHibernateUtil.Int64, 1)); diff --git a/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs b/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs index b12e4e7e041..574aa13a136 100644 --- a/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs +++ b/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Text; using NHibernate.Engine; using NHibernate.SqlCommand; @@ -22,15 +24,34 @@ namespace NHibernate.Dialect.Function ///]]> /// [Serializable] - public class AnsiSubstringFunction : ISQLFunction + public class AnsiSubstringFunction : ISQLFunction, ISQLFunctionExtended { #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.String; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "substring"; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs b/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs index e9cde7d70f2..995764d303e 100644 --- a/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs +++ b/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs @@ -1,7 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; - +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; @@ -17,7 +17,7 @@ namespace NHibernate.Dialect.Function /// functionality. /// [Serializable] - public class AnsiTrimEmulationFunction : ISQLFunction, IFunctionGrammar + public class AnsiTrimEmulationFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended { private static readonly ISQLFunction LeadingSpaceTrim = new SQLFunctionTemplate(NHibernateUtil.String, "ltrim( ?1 )"); private static readonly ISQLFunction TrailingSpaceTrim = new SQLFunctionTemplate(NHibernateUtil.String, "rtrim( ?1 )"); @@ -76,11 +76,30 @@ public AnsiTrimEmulationFunction(string replaceFunction) #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.String; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => null; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs b/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs index 47f8903cb48..c780479747f 100644 --- a/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs +++ b/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs @@ -1,6 +1,6 @@ using System; +using System.Collections.Generic; using NHibernate.Engine; -using NHibernate.SqlTypes; using NHibernate.Type; namespace NHibernate.Dialect.Function @@ -10,27 +10,22 @@ class AvgQueryFunctionInfo : ClassicAggregateFunction { public AvgQueryFunctionInfo() : base("avg", false) { } + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public override IType ReturnType(IType columnType, IMapping mapping) { - if (columnType == null) - { - throw new ArgumentNullException("columnType"); - } - SqlType[] sqlTypes; - try - { - sqlTypes = columnType.SqlTypes(mapping); - } - catch (MappingException me) - { - throw new QueryException(me); - } + return GetReturnType(new[] {columnType}, mapping, true); + } - if (sqlTypes.Length != 1) + /// + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + if (!TryGetArgumentType(argumentTypes, mapping, throwOnError, out _, out _)) { - throw new QueryException("multi-column type can not be in avg()"); + return null; } + return NHibernateUtil.Double; } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs b/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs index 5ebfbbe4c59..db6ba0c5365 100644 --- a/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs +++ b/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; @@ -27,8 +29,9 @@ namespace NHibernate.Dialect.Function /// Treats bitwise operations as SQL function calls. /// [Serializable] - public class BitwiseFunctionOperation : ISQLFunction + public class BitwiseFunctionOperation : ISQLFunction, ISQLFunctionExtended { + // TODO 6.0: convert FunctionName to read-only auto-property private readonly string _functionName; /// @@ -45,11 +48,30 @@ public BitwiseFunctionOperation(string functionName) #region ISQLFunction Members /// + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int64; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => _functionName; + /// public bool HasArguments => true; diff --git a/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs b/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs index 01e9e462d97..3bffb4b1d59 100644 --- a/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs +++ b/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; @@ -32,7 +34,7 @@ namespace NHibernate.Dialect.Function /// Treats bitwise operations as native operations. /// [Serializable] - public class BitwiseNativeOperation : ISQLFunction + public class BitwiseNativeOperation : ISQLFunction, ISQLFunctionExtended { private readonly string _sqlOpToken; private readonly bool _isUnary; @@ -65,11 +67,30 @@ public BitwiseNativeOperation(string sqlOpToken, bool isUnary) #region ISQLFunction Members /// + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int64; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => null; + /// public bool HasArguments => true; diff --git a/src/NHibernate/Dialect/Function/CastFunction.cs b/src/NHibernate/Dialect/Function/CastFunction.cs index 8580da1997f..2e02ea9dad1 100644 --- a/src/NHibernate/Dialect/Function/CastFunction.cs +++ b/src/NHibernate/Dialect/Function/CastFunction.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Xml; using NHibernate.Engine; using NHibernate.SqlCommand; @@ -12,10 +14,12 @@ namespace NHibernate.Dialect.Function /// ANSI-SQL style cast(foo as type) where the type is a NHibernate type /// [Serializable] - public class CastFunction : ISQLFunction, IFunctionGrammar + public class CastFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended { #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { //note there is a weird implementation in the client side @@ -23,6 +27,23 @@ public IType ReturnType(IType columnType, IMapping mapping) return columnType; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "cast"; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/CharIndexFunction.cs b/src/NHibernate/Dialect/Function/CharIndexFunction.cs index 1cdfafbff8c..7cc9f7bf543 100644 --- a/src/NHibernate/Dialect/Function/CharIndexFunction.cs +++ b/src/NHibernate/Dialect/Function/CharIndexFunction.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Text; using NHibernate.Engine; using NHibernate.SqlCommand; @@ -11,7 +13,7 @@ namespace NHibernate.Dialect.Function /// Emulation of locate() on Sybase /// [Serializable] - public class CharIndexFunction : ISQLFunction + public class CharIndexFunction : ISQLFunction, ISQLFunctionExtended { public CharIndexFunction() { @@ -19,11 +21,30 @@ public CharIndexFunction() #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int32; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "charindex"; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs index e0b78f1e1e2..cf0d4cf219c 100644 --- a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs +++ b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs @@ -4,6 +4,7 @@ using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; +using NHibernate.SqlTypes; using NHibernate.Type; namespace NHibernate.Dialect.Function @@ -40,15 +41,25 @@ public ClassicAggregateFunction(string name, bool acceptAsterisk, IType typeValu #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public virtual IType ReturnType(IType columnType, IMapping mapping) { return returnType ?? columnType; } /// - public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) { +#pragma warning disable 618 return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); } /// @@ -100,6 +111,63 @@ public SqlString Render(IList args, ISessionFactoryImplementor factory) #endregion + protected bool TryGetArgumentType( + IEnumerable argumentTypes, + IMapping mapping, + bool throwOnError, + out IType argumentType, + out SqlType sqlType) + { + sqlType = null; + argumentType = null; + if (argumentTypes.Count() != 1) + { + if (throwOnError) + { + throw new QueryException($"Invalid number of arguments for {name}()"); + } + + return false; + } + + argumentType = argumentTypes.First(); + if (argumentType == null) + { + // The argument is a parameter (e.g. select avg(:p1) from OrderLine). In that case, if the datatype is needed + // a QueryException will be thrown in SelectClause class, otherwise the query will be executed + // (e.g. select case when avg(:p1) > 0 then 1 else 0 end from OrderLine). + return false; + } + + SqlType[] sqlTypes; + try + { + sqlTypes = argumentType.SqlTypes(mapping); + } + catch (MappingException me) + { + if (throwOnError) + { + throw new QueryException(me); + } + + return false; + } + + if (sqlTypes.Length != 1) + { + if (throwOnError) + { + throw new QueryException($"Multi-column type can not be in {name}()"); + } + + return false; + } + + sqlType = sqlTypes[0]; + return true; + } + public override string ToString() { return name; diff --git a/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs b/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs index ec802a85096..737ed7f2ae4 100644 --- a/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs +++ b/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs @@ -1,7 +1,7 @@ using System; +using System.Collections.Generic; using System.Data; using NHibernate.Engine; -using NHibernate.SqlTypes; using NHibernate.Type; namespace NHibernate.Dialect.Function @@ -16,36 +16,28 @@ public ClassicAvgFunction() : base("avg", false) { } + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public override IType ReturnType(IType columnType, IMapping mapping) { - if (columnType == null) - { - throw new ArgumentNullException("columnType"); - } - SqlType[] sqlTypes; - try - { - sqlTypes = columnType.SqlTypes(mapping); - } - catch (MappingException me) - { - throw new QueryException(me); - } + return GetReturnType(new[] {columnType}, mapping, true); + } - if (sqlTypes.Length != 1) + /// + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + if (!TryGetArgumentType(argumentTypes, mapping, throwOnError, out var argumentType, out var sqlType)) { - throw new QueryException("multi-column type can not be in avg()"); + return null; } - SqlType sqlType = sqlTypes[0]; - if (sqlType.DbType == DbType.Int16 || sqlType.DbType == DbType.Int32 || sqlType.DbType == DbType.Int64) { return NHibernateUtil.Single; } else { - return columnType; + return argumentType; } } } diff --git a/src/NHibernate/Dialect/Function/ClassicCountFunction.cs b/src/NHibernate/Dialect/Function/ClassicCountFunction.cs index 821146519e9..6cf138866df 100644 --- a/src/NHibernate/Dialect/Function/ClassicCountFunction.cs +++ b/src/NHibernate/Dialect/Function/ClassicCountFunction.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.Type; @@ -14,9 +16,19 @@ public ClassicCountFunction() : base("count", true) { } + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public override IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int32; } + + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + // 6.0 TODO: return NHibernateUtil.Int32; +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs b/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs index b201a7522ea..6f2d873be28 100644 --- a/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs +++ b/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.Type; @@ -9,9 +11,19 @@ class CountQueryFunctionInfo : ClassicAggregateFunction { public CountQueryFunctionInfo() : base("count", true) { } + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public override IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int64; } + + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + // 6.0 TODO: return NHibernateUtil.Int64; +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Dialect/Function/ISQLFunction.cs b/src/NHibernate/Dialect/Function/ISQLFunction.cs index 5302625a01e..8ced743a4b3 100644 --- a/src/NHibernate/Dialect/Function/ISQLFunction.cs +++ b/src/NHibernate/Dialect/Function/ISQLFunction.cs @@ -1,3 +1,4 @@ +using System; using System.Collections; using System.Collections.Generic; using System.Linq; @@ -23,6 +24,8 @@ public interface ISQLFunction /// The type of the first argument /// /// + // Since v5.3 + [Obsolete("Use GetReturnType extension method instead.")] IType ReturnType(IType columnType, IMapping mapping); /// @@ -45,7 +48,7 @@ public interface ISQLFunction } // 6.0 TODO: Remove - internal static class SQLFunctionExtensions + public static class SQLFunctionExtensions { /// /// Get the type that will be effectively returned by the underlying database. @@ -68,7 +71,9 @@ public static IType GetEffectiveReturnType( { try { +#pragma warning disable 618 return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 } catch (QueryException) { @@ -83,5 +88,42 @@ public static IType GetEffectiveReturnType( return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError); } + + /// + /// Get the function general return type, ignoring underlying database specifics. + /// + /// The sql function. + /// The types of arguments. + /// The mapping for retrieving the argument sql types. + /// Whether to throw when the number of arguments is invalid or they are not supported. + /// The type returned by the underlying database or when the number of arguments + /// is invalid or they are not supported. + public static IType GetReturnType( + this ISQLFunction sqlFunction, + IEnumerable argumentTypes, + IMapping mapping, + bool throwOnError) + { + if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction)) + { + try + { +#pragma warning disable 618 + return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + catch (QueryException) + { + if (throwOnError) + { + throw; + } + + return null; + } + } + + return extendedSqlFunction.GetReturnType(argumentTypes, mapping, throwOnError); + } } } diff --git a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs index e2db4747198..8534f347c80 100644 --- a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs +++ b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs @@ -12,6 +12,16 @@ internal interface ISQLFunctionExtended : ISQLFunction /// string FunctionName { get; } + /// + /// Get the function general return type, ignoring underlying database specifics. + /// + /// The types of arguments. + /// The mapping for retrieving the argument sql types. + /// Whether to throw when the number of arguments is invalid or they are not supported. + /// The type returned by the underlying database or when the number of arguments + /// is invalid or they are not supported. + IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError); + /// /// Get the type that will be effectively returned by the underlying database. /// diff --git a/src/NHibernate/Dialect/Function/IifSQLFunction.cs b/src/NHibernate/Dialect/Function/IifSQLFunction.cs new file mode 100644 index 00000000000..3aee491dc3f --- /dev/null +++ b/src/NHibernate/Dialect/Function/IifSQLFunction.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + [Serializable] + internal class IifSQLFunction : SQLFunctionTemplate + { + public IifSQLFunction() : base(null, "case when ?1 then ?2 else ?3 end") + { + } + + /// + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + var args = argumentTypes.ToList(); + if (args.Count != 3) + { + if (throwOnError) + { + throw new QueryException($"Invalid number of arguments for iif()"); + } + + return null; + } + + return args[1] ?? args[2]; + } + } +} diff --git a/src/NHibernate/Dialect/Function/IifSafeSQLFunction.cs b/src/NHibernate/Dialect/Function/IifSafeSQLFunction.cs new file mode 100644 index 00000000000..7828d68b9aa --- /dev/null +++ b/src/NHibernate/Dialect/Function/IifSafeSQLFunction.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + [Serializable] + internal class IifSafeSQLFunction : StandardSafeSQLFunction + { + public IifSafeSQLFunction() : base("iif", 3) + { + } + + /// + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + var args = argumentTypes.ToList(); + if (args.Count != 3) + { + return null; // Not enough information + } + + return args[1] ?? args[2]; + } + } +} diff --git a/src/NHibernate/Dialect/Function/ModulusFunction.cs b/src/NHibernate/Dialect/Function/ModulusFunction.cs new file mode 100644 index 00000000000..46b287b4ab4 --- /dev/null +++ b/src/NHibernate/Dialect/Function/ModulusFunction.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using NHibernate.Engine; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + [Serializable] + internal class ModulusFunction : StandardSafeSQLFunction + { + private readonly ModulusFunctionTypeDetector _modulusFunctionTypeDetector; + + public ModulusFunction(bool supportDecimals, bool supportFloatingNumbers) + : this(new ModulusFunctionTypeDetector(supportDecimals, supportFloatingNumbers)) + { + } + + public ModulusFunction(ModulusFunctionTypeDetector modulusFunction) : base("mod", NHibernateUtil.Int32, 2) + { + _modulusFunctionTypeDetector = modulusFunction; + } + + /// + public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return _modulusFunctionTypeDetector.GetReturnType(argumentTypes, mapping, throwOnError); + } + } +} diff --git a/src/NHibernate/Dialect/Function/ModulusFunctionTemplate.cs b/src/NHibernate/Dialect/Function/ModulusFunctionTemplate.cs new file mode 100644 index 00000000000..881cc7816be --- /dev/null +++ b/src/NHibernate/Dialect/Function/ModulusFunctionTemplate.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using NHibernate.Engine; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + [Serializable] + internal class ModulusFunctionTemplate : SQLFunctionTemplate + { + private readonly ModulusFunctionTypeDetector _modulusFunctionTypeDetector; + + public ModulusFunctionTemplate(bool supportDecimals) : this(new ModulusFunctionTypeDetector(supportDecimals)) + { + } + + public ModulusFunctionTemplate(ModulusFunctionTypeDetector modulusFunction) : base(NHibernateUtil.Int32, "((?1) % (?2))") + { + _modulusFunctionTypeDetector = modulusFunction; + } + + /// + public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return _modulusFunctionTypeDetector.GetReturnType(argumentTypes, mapping, throwOnError); + } + } +} diff --git a/src/NHibernate/Dialect/Function/ModulusFunctionTypeDetector.cs b/src/NHibernate/Dialect/Function/ModulusFunctionTypeDetector.cs new file mode 100644 index 00000000000..eab42090b12 --- /dev/null +++ b/src/NHibernate/Dialect/Function/ModulusFunctionTypeDetector.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; +using System.Data; +using NHibernate.Engine; +using NHibernate.SqlTypes; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + [Serializable] + internal class ModulusFunctionTypeDetector + { + // The supported DbTypes with their priorities in order to detect which is the + // returned type when mixing them + private readonly Lazy>> _supportedDbTypesLazy; + private readonly bool _supportDecimals; + private readonly bool _supportFloatingNumbers; + + public ModulusFunctionTypeDetector(bool supportDecimals, bool supportFloatingNumbers) + { + _supportDecimals = supportDecimals; + _supportFloatingNumbers = supportFloatingNumbers; + _supportedDbTypesLazy = new Lazy>>(GetSupportedTypes); + } + + public ModulusFunctionTypeDetector(bool supportDecimals) : this(supportDecimals, false) + { + } + + protected virtual Dictionary> GetSupportedTypes() + { + var types = new Dictionary>() + { + {DbType.Int16, new KeyValuePair(1, NHibernateUtil.Int16)}, + {DbType.Int32, new KeyValuePair(2, NHibernateUtil.Int32)}, + {DbType.Int64, new KeyValuePair(3, NHibernateUtil.Int64)}, + }; + + if (_supportDecimals) + { + types.Add(DbType.Currency, new KeyValuePair(4, NHibernateUtil.Decimal)); + types.Add(DbType.Decimal, new KeyValuePair(4, NHibernateUtil.Decimal)); + } + + if (_supportFloatingNumbers) + { + types.Add(DbType.Single, new KeyValuePair(5, NHibernateUtil.Single)); + types.Add(DbType.Double, new KeyValuePair(6, NHibernateUtil.Double)); + } + + return types; + } + + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + KeyValuePair currentReturnType = default; + int totalArguments = 0; + foreach (var argumentType in argumentTypes) + { + if (argumentType == null) + { + return null; + } + + SqlType[] sqlTypes; + try + { + sqlTypes = argumentType.SqlTypes(mapping); + } + catch (MappingException me) + { + if (throwOnError) + { + throw new QueryException(me); + } + + return null; + } + + if (sqlTypes.Length != 1) + { + return ThrowOrReturnDefault("Multi-column type can not be in mod()", throwOnError); + } + + if (!_supportedDbTypesLazy.Value.TryGetValue(sqlTypes[0].DbType, out var returnType)) + { + return ThrowOrReturnDefault($"DbType {sqlTypes[0].DbType} is not supported for mod()", throwOnError); + } + + if (returnType.Key > currentReturnType.Key) + { + currentReturnType = returnType; + } + + totalArguments++; + } + + return totalArguments == 2 + ? currentReturnType.Value + : ThrowOrReturnDefault("Invalid number of arguments for mod()", throwOnError); + } + + private IType ThrowOrReturnDefault(string error, bool throwOnError) + { + if (throwOnError) + { + throw new QueryException(error); + } + + return null; + } + } +} diff --git a/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs b/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs index e8e700a9773..1cca8bd6cac 100644 --- a/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs +++ b/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs @@ -3,6 +3,8 @@ using NHibernate.SqlCommand; using NHibernate.Type; using System; +using System.Collections.Generic; +using System.Linq; namespace NHibernate.Dialect.Function { @@ -10,7 +12,7 @@ namespace NHibernate.Dialect.Function /// Summary description for NoArgSQLFunction. /// [Serializable] - public class NoArgSQLFunction : ISQLFunction + public class NoArgSQLFunction : ISQLFunction, ISQLFunctionExtended { public NoArgSQLFunction(string name, IType returnType) : this(name, returnType, true) @@ -19,22 +21,47 @@ public NoArgSQLFunction(string name, IType returnType) public NoArgSQLFunction(string name, IType returnType, bool hasParenthesesIfNoArguments) { +#pragma warning disable 618 Name = name; +#pragma warning restore 618 FunctionReturnType = returnType; HasParenthesesIfNoArguments = hasParenthesesIfNoArguments; } public IType FunctionReturnType { get; protected set; } + // Since v5.3 + [Obsolete("Use FunctionName property instead.")] public string Name { get; protected set; } #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return FunctionReturnType; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// +#pragma warning disable 618 + public string FunctionName => Name; +#pragma warning restore 618 + public bool HasArguments { get { return false; } @@ -46,15 +73,15 @@ public virtual SqlString Render(IList args, ISessionFactoryImplementor factory) { if (args.Count > 0) { - throw new QueryException("function takes no arguments: " + Name); + throw new QueryException("function takes no arguments: " + FunctionName); } if (HasParenthesesIfNoArguments) { - return new SqlString(Name + "()"); + return new SqlString(FunctionName + "()"); } - return new SqlString(Name); + return new SqlString(FunctionName); } #endregion diff --git a/src/NHibernate/Dialect/Function/NvlFunction.cs b/src/NHibernate/Dialect/Function/NvlFunction.cs index aa17cf0fbd1..9ffba90f04f 100644 --- a/src/NHibernate/Dialect/Function/NvlFunction.cs +++ b/src/NHibernate/Dialect/Function/NvlFunction.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; @@ -10,7 +12,7 @@ namespace NHibernate.Dialect.Function /// Emulation of coalesce() on Oracle, using multiple nvl() calls /// [Serializable] - public class NvlFunction : ISQLFunction + public class NvlFunction : ISQLFunction, ISQLFunctionExtended { public NvlFunction() { @@ -18,11 +20,30 @@ public NvlFunction() #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return columnType; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "nvl"; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs b/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs index 9a89bda39c4..717503d6b12 100644 --- a/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs +++ b/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Text; using Antlr.Runtime; using NHibernate.Engine; @@ -12,7 +14,7 @@ namespace NHibernate.Dialect.Function /// Emulation of locate() on PostgreSQL /// [Serializable] - public class PositionSubstringFunction : ISQLFunction + public class PositionSubstringFunction : ISQLFunction, ISQLFunctionExtended { public PositionSubstringFunction() { @@ -20,11 +22,30 @@ public PositionSubstringFunction() #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int32; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "position"; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs b/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs index 976a5f0e127..4b5a7e0fdc7 100644 --- a/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs +++ b/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs @@ -6,6 +6,8 @@ using NHibernate.SqlCommand; using NHibernate.Type; using System; +using System.Collections.Generic; +using System.Linq; namespace NHibernate.Dialect.Function { @@ -18,7 +20,7 @@ namespace NHibernate.Dialect.Function /// parameters with '?' followed by parameter's index (first index is 1). /// [Serializable] - public class SQLFunctionTemplate : ISQLFunction + public class SQLFunctionTemplate : ISQLFunction, ISQLFunctionExtended { private const int InvalidArgumentIndex = -1; private static readonly Regex SplitRegex = new Regex("(\\?[0-9]+)"); @@ -80,11 +82,30 @@ private void InitFromTemplate() #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return (returnType == null) ? columnType : returnType; } + /// + public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public virtual string FunctionName => null; + public bool HasArguments { get { return hasArguments; } diff --git a/src/NHibernate/Dialect/Function/StandardSQLFunction.cs b/src/NHibernate/Dialect/Function/StandardSQLFunction.cs index df1adc53f09..279e9089919 100644 --- a/src/NHibernate/Dialect/Function/StandardSQLFunction.cs +++ b/src/NHibernate/Dialect/Function/StandardSQLFunction.cs @@ -4,6 +4,8 @@ using NHibernate.SqlCommand; using NHibernate.Type; using System; +using System.Collections.Generic; +using System.Linq; namespace NHibernate.Dialect.Function { @@ -16,7 +18,7 @@ namespace NHibernate.Dialect.Function /// for processing of the associated function. /// [Serializable] - public class StandardSQLFunction : ISQLFunction + public class StandardSQLFunction : ISQLFunction, ISQLFunctionExtended { private IType returnType = null; protected readonly string name; @@ -43,11 +45,30 @@ public StandardSQLFunction(string name, IType typeValue) #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public virtual IType ReturnType(IType columnType, IMapping mapping) { return returnType ?? columnType; } + /// + public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => name; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs b/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs index 3ee6c7d6f2e..6c0af160841 100644 --- a/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs +++ b/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs @@ -1,7 +1,7 @@ using System; +using System.Collections.Generic; using System.Data; using NHibernate.Engine; -using NHibernate.SqlTypes; using NHibernate.Type; namespace NHibernate.Dialect.Function @@ -12,29 +12,21 @@ class SumQueryFunctionInfo : ClassicAggregateFunction public SumQueryFunctionInfo() : base("sum", false) { } //H3.2 behavior + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public override IType ReturnType(IType columnType, IMapping mapping) { - if (columnType == null) - { - throw new ArgumentNullException("columnType"); - } - SqlType[] sqlTypes; - try - { - sqlTypes = columnType.SqlTypes(mapping); - } - catch (MappingException me) - { - throw new QueryException(me); - } + return GetReturnType(new[] { columnType }, mapping, true); + } - if (sqlTypes.Length != 1) + /// + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + if (!TryGetArgumentType(argumentTypes, mapping, throwOnError, out var argumentType, out var sqlType)) { - throw new QueryException("multi-column type can not be in sum()"); + return null; } - SqlType sqlType = sqlTypes[0]; - // TODO: (H3.2 for nullable types) First allow the actual type to control the return value. (the actual underlying sqltype could actually be different) // finally use the sqltype if == on Hibernate types did not find a match. @@ -57,8 +49,8 @@ public override IType ReturnType(IType columnType, IMapping mapping) return NHibernateUtil.UInt64; default: - return columnType; + return argumentType; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs b/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs index 457a1b86d8e..64a0aa213da 100644 --- a/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs +++ b/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs @@ -1,5 +1,7 @@ using System; using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Text; using NHibernate.Engine; using NHibernate.SqlCommand; @@ -12,7 +14,7 @@ namespace NHibernate.Dialect.Function /// with an unlimited number of arguments. /// [Serializable] - public class VarArgsSQLFunction : ISQLFunction + public class VarArgsSQLFunction : ISQLFunction, ISQLFunctionExtended { private readonly string begin; private readonly string sep; @@ -34,11 +36,30 @@ public VarArgsSQLFunction(IType type, string begin, string sep, string end) #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public virtual IType ReturnType(IType columnType, IMapping mapping) { return (returnType == null) ? columnType : returnType; } + /// + public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public virtual string FunctionName => null; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/HanaDialectBase.cs b/src/NHibernate/Dialect/HanaDialectBase.cs index 4a8d70db2e8..99203e74605 100644 --- a/src/NHibernate/Dialect/HanaDialectBase.cs +++ b/src/NHibernate/Dialect/HanaDialectBase.cs @@ -1,7 +1,9 @@ using System; using System.Collections; +using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Linq; using NHibernate.Dialect.Function; using NHibernate.Dialect.Schema; using NHibernate.Engine; @@ -17,7 +19,7 @@ namespace NHibernate.Dialect public abstract class HanaDialectBase : Dialect { [Serializable] - private class TypeConvertingVarArgsSQLFunction : ISQLFunction + private class TypeConvertingVarArgsSQLFunction : ISQLFunction, ISQLFunctionExtended { private readonly string _begin; private readonly string _sep; @@ -33,12 +35,31 @@ public TypeConvertingVarArgsSQLFunction(string begin, string sep, string end) #region ISQLFunction Members + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { _type = columnType.SqlTypes(mapping)[0]; return columnType; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public virtual string FunctionName => "cast"; + public bool HasArguments => true; public bool HasParenthesesIfNoArguments => true; @@ -392,7 +413,7 @@ protected virtual void RegisterNHibernateFunctions() RegisterFunction("ceiling", new StandardSQLFunction("ceil")); RegisterFunction("chr", new StandardSQLFunction("char", NHibernateUtil.AnsiChar)); RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.Date, "to_date(?1)")); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); RegisterFunction("sysdate", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, false)); RegisterFunction("truncate", new SQLFunctionTemplateWithRequiredParameters(null, "floor(?1 * power(10, ?2)) / power(10, ?2)", new object[] { null, "0" })); RegisterFunction("new_uuid", new NoArgSQLFunction("sysuuid", NHibernateUtil.Guid, false)); @@ -498,7 +519,7 @@ protected virtual void RegisterHANAFunctions() RegisterFunction("map", new VarArgsSQLFunction("map(", ",", ")")); RegisterFunction("mimetype", new StandardSQLFunction("mimetype", NHibernateUtil.String)); RegisterFunction("minute", new StandardSQLFunction("minute", NHibernateUtil.Int32)); - RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32)); + RegisterFunction("mod", new ModulusFunction(false, false)); RegisterFunction("month", new StandardSQLFunction("month", NHibernateUtil.Int32)); RegisterFunction("monthname", new StandardSQLFunction("monthname", NHibernateUtil.String)); RegisterFunction("months_between", new StandardSQLFunction("months_between", NHibernateUtil.Int32)); diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index 35c760e85b3..1aec544ed2d 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -313,7 +313,7 @@ protected virtual void RegisterFunctions() RegisterFunction("ln", new StandardSQLFunction("ln", NHibernateUtil.Double)); RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Double)); RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double)); - RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); + RegisterFunction("mod", new ModulusFunctionTemplate(true)); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); // SQL Server rand returns the same value for each row, unless hacking it with a random seed per row @@ -351,7 +351,7 @@ protected virtual void RegisterFunctions() RegisterFunction("ltrim", new StandardSQLFunction("ltrim")); RegisterFunction("trim", new AnsiTrimEmulationFunction()); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); RegisterFunction("replace", new StandardSafeSQLFunction("replace", NHibernateUtil.String, 3)); // Casting to CHAR (without specified length) truncates to 30 characters. diff --git a/src/NHibernate/Dialect/MsSql2012Dialect.cs b/src/NHibernate/Dialect/MsSql2012Dialect.cs index 8649d6861db..7ec84ef66f9 100644 --- a/src/NHibernate/Dialect/MsSql2012Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2012Dialect.cs @@ -53,7 +53,7 @@ public override string QuerySequencesString protected override void RegisterFunctions() { base.RegisterFunctions(); - RegisterFunction("iif", new StandardSafeSQLFunction("iif", 3)); + RegisterFunction("iif", new IifSafeSQLFunction()); } public override SqlString GetLimitString(SqlString querySqlString, SqlString offset, SqlString limit) diff --git a/src/NHibernate/Dialect/MsSqlCeDialect.cs b/src/NHibernate/Dialect/MsSqlCeDialect.cs index fd0baed402b..86edd426e6e 100644 --- a/src/NHibernate/Dialect/MsSqlCeDialect.cs +++ b/src/NHibernate/Dialect/MsSqlCeDialect.cs @@ -191,10 +191,11 @@ protected virtual void RegisterFunctions() RegisterFunction("lower", new StandardSQLFunction("lower")); RegisterFunction("trim", new AnsiTrimEmulationFunction()); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "+", ")")); - RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); + // Modulo is not supported on real, float, money, and numeric data types + RegisterFunction("mod", new ModulusFunctionTemplate(false)); RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"})); RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0", "1"})); diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index a6caa6d236a..b43c2c632c3 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -1,11 +1,15 @@ using System; +using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Linq; using System.Text; using NHibernate.Dialect.Function; using NHibernate.Dialect.Schema; +using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.SqlTypes; +using NHibernate.Type; using NHibernate.Util; using Environment=NHibernate.Cfg.Environment; @@ -242,7 +246,8 @@ protected virtual void RegisterKeywords() protected virtual void RegisterFunctions() { - RegisterFunction("iif", new StandardSQLFunction("if")); + RegisterFunction("iif", new IfSQLFunction()); + RegisterFunction("mod", new ModulusFunction(true, true)); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); @@ -549,5 +554,30 @@ public override long TimestampResolutionInTicks public override bool SupportsDistributedTransactions => false; #endregion + + [Serializable] + internal class IfSQLFunction : StandardSQLFunction + { + public IfSQLFunction() : base("if") + { + } + + /// + public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + var args = argumentTypes.ToList(); + if (args.Count != 3) + { + if (throwOnError) + { + throw new QueryException($"Invalid number of arguments for iif()"); + } + + return null; + } + + return args[1] ?? args[2]; + } + } } } diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index 749c1f0d056..c484b7f35a2 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Linq; using NHibernate.Dialect.Function; using NHibernate.Dialect.Schema; using NHibernate.Engine; @@ -291,7 +292,7 @@ protected virtual void RegisterFunctions() // Multi-param numeric dialect functions... RegisterFunction("atan2", new StandardSQLFunction("atan2", NHibernateUtil.Double)); RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Int32)); - RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32)); + RegisterFunction("mod", new ModulusFunction(true, true)); RegisterFunction("nvl", new StandardSQLFunction("nvl")); RegisterFunction("nvl2", new StandardSQLFunction("nvl2")); RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); @@ -304,7 +305,7 @@ protected virtual void RegisterFunctions() RegisterFunction("str", new StandardSQLFunction("to_char", NHibernateUtil.String)); RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "substr(rawtohex(?1), 7, 2) || substr(rawtohex(?1), 5, 2) || substr(rawtohex(?1), 3, 2) || substr(rawtohex(?1), 1, 2) || '-' || substr(rawtohex(?1), 11, 2) || substr(rawtohex(?1), 9, 2) || '-' || substr(rawtohex(?1), 15, 2) || substr(rawtohex(?1), 13, 2) || '-' || substr(rawtohex(?1), 17, 4) || '-' || substr(rawtohex(?1), 21) ")); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); RegisterFunction("band", new Function.BitwiseFunctionOperation("bitand")); RegisterFunction("bor", new SQLFunctionTemplate(null, "?1 + ?2 - BITAND(?1, ?2)")); @@ -577,11 +578,11 @@ public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.LocalDateTi public override SqlString Render(IList args, ISessionFactoryImplementor factory) { - return new SqlString(Name); + return new SqlString(FunctionName); } } [Serializable] - private class LocateFunction : ISQLFunction + private class LocateFunction : ISQLFunction, ISQLFunctionExtended { private static readonly ISQLFunction LocateWith2Params = new SQLFunctionTemplate(NHibernateUtil.Int32, "instr(?2, ?1)"); @@ -591,11 +592,30 @@ private class LocateFunction : ISQLFunction #region Implementation of ISQLFunction + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) { return NHibernateUtil.Int32; } + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => "instr"; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/OracleLiteDialect.cs b/src/NHibernate/Dialect/OracleLiteDialect.cs index cf5e33914a5..eb5af510652 100644 --- a/src/NHibernate/Dialect/OracleLiteDialect.cs +++ b/src/NHibernate/Dialect/OracleLiteDialect.cs @@ -101,7 +101,7 @@ public OracleLiteDialect() RegisterFunction("translate", new StandardSQLFunction("translate", NHibernateUtil.String)); // Multi-param numeric dialect functions... - RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32)); + RegisterFunction("mod", new ModulusFunction(true, false)); RegisterFunction("nvl", new StandardSQLFunction("nvl")); // Multi-param date dialect functions... diff --git a/src/NHibernate/Dialect/PostgreSQLDialect.cs b/src/NHibernate/Dialect/PostgreSQLDialect.cs index baa9334d788..44bbdaff493 100644 --- a/src/NHibernate/Dialect/PostgreSQLDialect.cs +++ b/src/NHibernate/Dialect/PostgreSQLDialect.cs @@ -1,7 +1,9 @@ using System; using System.Collections; +using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Linq; using NHibernate.Dialect.Function; using NHibernate.Dialect.Schema; using NHibernate.Engine; @@ -63,10 +65,10 @@ public PostgreSQLDialect() RegisterFunction("current_timestamp", new NoArgSQLFunction("now", NHibernateUtil.LocalDateTime, true)); RegisterFunction("str", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as varchar)")); RegisterFunction("locate", new PositionSubstringFunction()); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); RegisterFunction("replace", new StandardSQLFunction("replace", NHibernateUtil.String)); RegisterFunction("left", new SQLFunctionTemplate(NHibernateUtil.String, "substr(?1,1,?2)")); - RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); + RegisterFunction("mod", new ModulusFunctionTemplate(true)); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); RegisterFunction("round", new RoundFunction(false)); @@ -347,7 +349,7 @@ public override string CurrentTimestampSelectString #endregion [Serializable] - private class RoundFunction : ISQLFunction + private class RoundFunction : ISQLFunction, ISQLFunctionExtended { private static readonly ISQLFunction Round = new StandardSQLFunction("round"); private static readonly ISQLFunction Truncate = new StandardSQLFunction("trunc"); @@ -361,7 +363,7 @@ private class RoundFunction : ISQLFunction private readonly ISQLFunction _singleParamFunction; private readonly ISQLFunction _twoParamFunction; - private readonly string _name; + private readonly string _name; // TODO 6.0: convert FunctionName to read-only auto property public RoundFunction(bool truncate) { @@ -379,8 +381,27 @@ public RoundFunction(bool truncate) } } + // Since v5.3 + [Obsolete("Use GetReturnType method instead.")] public IType ReturnType(IType columnType, IMapping mapping) => columnType; + /// + public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { +#pragma warning disable 618 + return ReturnType(argumentTypes.FirstOrDefault(), mapping); +#pragma warning restore 618 + } + + /// + public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return GetReturnType(argumentTypes, mapping, throwOnError); + } + + /// + public string FunctionName => _name; + public bool HasArguments => true; public bool HasParenthesesIfNoArguments => true; diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index 5eb1f51dd87..fa71342185f 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -110,9 +110,9 @@ protected virtual void RegisterFunctions() RegisterFunction("replace", new StandardSafeSQLFunction("replace", NHibernateUtil.String, 3)); RegisterFunction("chr", new StandardSQLFunction("char", NHibernateUtil.Character)); - RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); + RegisterFunction("mod", new ModulusFunctionTemplate(false)); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); RegisterFunction("round", new StandardSQLFunction("round")); diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index a5233395d36..e45b1f7700a 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -93,7 +93,7 @@ public SybaseASE15Dialect() RegisterFunction("lower", new StandardSQLFunction("lower")); RegisterFunction("ltrim", new StandardSQLFunction("ltrim")); RegisterFunction("minute", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(minute, ?1)")); - RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "?1 % ?2")); + RegisterFunction("mod", new ModulusFunctionTemplate(false)); RegisterFunction("month", new StandardSQLFunction("month", NHibernateUtil.Int32)); RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double)); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs index f5b61250049..290e626671f 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs @@ -137,7 +137,7 @@ protected virtual void RegisterMathFunctions() RegisterFunction("floor", new StandardSQLFunction("floor", NHibernateUtil.Double)); RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Double)); RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double)); - RegisterFunction("mod", new StandardSQLFunction("mod")); + RegisterFunction("mod", new ModulusFunction(false, false)); RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double, true)); RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); @@ -355,7 +355,7 @@ protected virtual void RegisterMiscellaneousFunctions() RegisterFunction("transactsql", new StandardSQLFunction("transactsql", NHibernateUtil.String)); RegisterFunction("varexists", new StandardSQLFunction("varexists", NHibernateUtil.Int32)); RegisterFunction("watcomsql", new StandardSQLFunction("watcomsql", NHibernateUtil.String)); - RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); + RegisterFunction("iif", new IifSQLFunction()); } #region private static readonly string[] DialectKeywords = { ... } diff --git a/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs b/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs index b88f22ec0a6..33d6103ec86 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Linq; using NHibernate.Dialect.Function; using NHibernate.Engine; using NHibernate.Hql.Ast.ANTLR.Tree; @@ -9,6 +11,7 @@ using NHibernate.Type; using NHibernate.Util; using IASTNode=NHibernate.Hql.Ast.ANTLR.Tree.IASTNode; +using IQueryable = NHibernate.Persister.Entity.IQueryable; namespace NHibernate.Hql.Ast.ANTLR { @@ -67,6 +70,8 @@ private ISQLFunction RequireSQLFunction(string functionName) /// The function name. /// The first argument expression. /// the function return type given the function name and the first argument expression node. + // Since v5.3 + [Obsolete("Please use overload with a IEnumerable parameter instead.")] public IType FindFunctionReturnType(String functionName, IASTNode first) { // locate the registered function by the given name @@ -90,6 +95,44 @@ public IType FindFunctionReturnType(String functionName, IASTNode first) return sqlFunction.ReturnType(argumentType, _sfi); } + /// + /// Find the function return type given the function name and the arguments expression nodes. + /// + /// The function name. + /// The function arguments expression nodes. + /// The function return type given the function name and the arguments expression nodes. + public IType FindFunctionReturnType(string functionName, IEnumerable arguments) + { + var sqlFunction = RequireSQLFunction(functionName); + if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction)) + { +#pragma warning disable 618 + return FindFunctionReturnType(functionName, arguments.FirstOrDefault()); +#pragma warning restore 618 + } + + var argumentTypes = new List(); + if (sqlFunction is CastFunction) + { + argumentTypes.Add(TypeFactory.HeuristicType(arguments.First().NextSibling.Text)); + } + else + { + foreach (var argument in arguments) + { + IType type = null; + if (argument is SqlNode sqlNode) + { + type = sqlNode.DataType; + } + + argumentTypes.Add(type); + } + } + + return extendedSqlFunction.GetReturnType(argumentTypes, _sfi, true); + } + /// /// Given a (potentially unqualified) class name, locate its imported qualified name. /// diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs index a3887a56c01..d330103082a 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using Antlr.Runtime; using NHibernate.Dialect.Function; using NHibernate.Type; @@ -38,7 +39,7 @@ public override IType DataType get { // Get the function return value type, based on the type of the first argument. - return SessionFactoryHelper.FindFunctionReturnType(Text, GetChild(0)); + return SessionFactoryHelper.FindFunctionReturnType(Text, (IEnumerable) this); } set { diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs index ef6d6272043..e381c921bbe 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs @@ -1,4 +1,5 @@ -using Antlr.Runtime; +using System.Linq; +using Antlr.Runtime; using NHibernate.Dialect.Function; using NHibernate.Hql.Ast.ANTLR.Util; using NHibernate.Type; @@ -20,7 +21,7 @@ public override IType DataType { get { - return SessionFactoryHelper.FindFunctionReturnType(Text, null); + return SessionFactoryHelper.FindFunctionReturnType(Text, Enumerable.Empty()); } set { diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs index 38bb20080bd..095f8b03545 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs @@ -1,13 +1,14 @@ using System; using System.Collections.Generic; +using System.Linq; using Antlr.Runtime; using NHibernate.Dialect.Function; using NHibernate.Hql.Ast.ANTLR.Util; using NHibernate.Persister.Collection; -using NHibernate.Persister.Entity; using NHibernate.SqlCommand; using NHibernate.Type; using NHibernate.Util; +using IQueryable = NHibernate.Persister.Entity.IQueryable; namespace NHibernate.Hql.Ast.ANTLR.Tree { @@ -38,7 +39,7 @@ public override IType DataType return fe.DataType; } ISQLFunction sf = Walker.SessionFactoryHelper.FindSQLFunction(Text); - return sf?.ReturnType(null, Walker.SessionFactoryHelper.Factory); + return sf?.GetReturnType(Enumerable.Empty(), Walker.SessionFactoryHelper.Factory, true); } set diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs index f36ac08ff69..dc6c04042e7 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using Antlr.Runtime; using NHibernate.Dialect.Function; @@ -194,14 +195,7 @@ private void DialectFunction(IASTNode exprList) if (_function != null) { - IASTNode child = null; - - if (exprList != null) - { - child = _methodName == "iif" ? exprList.GetChild(1) : exprList.GetChild(0); - } - - DataType = SessionFactoryHelper.FindFunctionReturnType(_methodName, child); + DataType = SessionFactoryHelper.FindFunctionReturnType(_methodName, (IEnumerable) exprList); } //TODO: /*else {