diff --git a/src/NHibernate/Criterion/SqlFunctionProjection.cs b/src/NHibernate/Criterion/SqlFunctionProjection.cs
index 95d5298fa77..0c3bba98f98 100644
--- a/src/NHibernate/Criterion/SqlFunctionProjection.cs
+++ b/src/NHibernate/Criterion/SqlFunctionProjection.cs
@@ -125,7 +125,7 @@ private IType GetReturnType(ICriteria criteria, ICriteriaQuery criteriaQuery)
var resultType = returnType ?? returnTypeProjection?.GetTypes(criteria, criteriaQuery).FirstOrDefault();
- return sqlFunction.ReturnType(resultType, criteriaQuery.Factory);
+ return sqlFunction.GetReturnType(new[] {resultType}, criteriaQuery.Factory, true);
}
public override TypedValue[] GetTypedValues(ICriteria criteria, ICriteriaQuery criteriaQuery)
diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs
index bd24dda28d7..4ec34033f9e 100644
--- a/src/NHibernate/Dialect/DB2Dialect.cs
+++ b/src/NHibernate/Dialect/DB2Dialect.cs
@@ -131,7 +131,7 @@ public DB2Dialect()
RegisterFunction("length", new StandardSQLFunction("length", NHibernateUtil.Int32));
RegisterFunction("ltrim", new StandardSQLFunction("ltrim"));
- RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32));
+ RegisterFunction("mod", new ModulusFunction(true, false));
RegisterFunction("substring", new StandardSQLFunction("substr", NHibernateUtil.String));
diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs
index 53c30b55175..6153309c7ac 100644
--- a/src/NHibernate/Dialect/Dialect.cs
+++ b/src/NHibernate/Dialect/Dialect.cs
@@ -94,7 +94,7 @@ protected Dialect()
RegisterFunction("coalesce", new StandardSQLFunction("coalesce"));
RegisterFunction("nullif", new StandardSQLFunction("nullif"));
RegisterFunction("abs", new StandardSQLFunction("abs"));
- RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32));
+ RegisterFunction("mod", new ModulusFunction(false, false));
RegisterFunction("sqrt", new StandardSQLFunction("sqrt", NHibernateUtil.Double));
RegisterFunction("upper", new StandardSQLFunction("upper"));
RegisterFunction("lower", new StandardSQLFunction("lower"));
diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs
index ba37c00cfaa..0f26f03d15a 100644
--- a/src/NHibernate/Dialect/FirebirdDialect.cs
+++ b/src/NHibernate/Dialect/FirebirdDialect.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections;
+using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using NHibernate.Dialect.Function;
@@ -147,7 +148,7 @@ public CastedFunction(string name, IType returnType) : base(name, returnType, fa
public override SqlString Render(IList args, ISessionFactoryImplementor factory)
{
- return new SqlString("cast('", Name, "' as ", FunctionReturnType.SqlTypes(factory)[0].ToString(), ")");
+ return new SqlString("cast('", FunctionName, "' as ", FunctionReturnType.SqlTypes(factory)[0].ToString(), ")");
}
}
@@ -160,7 +161,7 @@ public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.LocalDateTi
public override SqlString Render(IList args, ISessionFactoryImplementor factory)
{
- return new SqlString(Name);
+ return new SqlString(FunctionName);
}
}
@@ -205,7 +206,7 @@ public override string SelectGUIDString
}
[Serializable]
- private class PositionFunction : ISQLFunction
+ private class PositionFunction : ISQLFunction, ISQLFunctionExtended
{
// The cast is needed, at least in the case that ?3 is a named integer parameter, otherwise firebird will generate an error.
// We have a unit test to cover this potential firebird bug.
@@ -214,11 +215,28 @@ private class PositionFunction : ISQLFunction
private static readonly ISQLFunction LocateWith3Params = new SQLFunctionTemplate(NHibernateUtil.Int32,
"position(?1, ?2, cast(?3 as int))");
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int32;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return NHibernateUtil.Int32;
+ }
+
+ ///
+ public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "position";
+
public bool HasArguments
{
get { return true; }
@@ -418,7 +436,8 @@ private void OverrideStandardHQLFunctions()
RegisterFunction("nullif", new StandardSafeSQLFunction("nullif", 2));
RegisterFunction("lower", new StandardSafeSQLFunction("lower", NHibernateUtil.String, 1));
RegisterFunction("upper", new StandardSafeSQLFunction("upper", NHibernateUtil.String, 1));
- RegisterFunction("mod", new StandardSafeSQLFunction("mod", NHibernateUtil.Double, 2));
+ // Modulo does not throw for decimal parameters but they are casted to int by Firebird, which produces unexpected results
+ RegisterFunction("mod", new ModulusFunction(false, false));
RegisterFunction("str", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as VARCHAR(255))"));
RegisterFunction("strguid", new StandardSQLFunction("uuid_to_char", NHibernateUtil.String));
RegisterFunction("sysdate", new CastedFunction("today", NHibernateUtil.Date));
@@ -437,7 +456,7 @@ private void RegisterFirebirdServerEmbeddedFunctions()
RegisterFunction("yesterday", new CastedFunction("yesterday", NHibernateUtil.Date));
RegisterFunction("tomorrow", new CastedFunction("tomorrow", NHibernateUtil.Date));
RegisterFunction("now", new CastedFunction("now", NHibernateUtil.DateTime));
- RegisterFunction("iif", new StandardSafeSQLFunction("iif", 3));
+ RegisterFunction("iif", new IifSafeSQLFunction());
// New embedded functions in FB 2.0 (http://www.firebirdsql.org/rlsnotes20/rnfbtwo-str.html#str-string-func)
RegisterFunction("char_length", new StandardSafeSQLFunction("char_length", NHibernateUtil.Int64, 1));
RegisterFunction("bit_length", new StandardSafeSQLFunction("bit_length", NHibernateUtil.Int64, 1));
diff --git a/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs b/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs
index b12e4e7e041..574aa13a136 100644
--- a/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs
+++ b/src/NHibernate/Dialect/Function/AnsiSubstringFunction.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using System.Text;
using NHibernate.Engine;
using NHibernate.SqlCommand;
@@ -22,15 +24,34 @@ namespace NHibernate.Dialect.Function
///]]>
///
[Serializable]
- public class AnsiSubstringFunction : ISQLFunction
+ public class AnsiSubstringFunction : ISQLFunction, ISQLFunctionExtended
{
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.String;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "substring";
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs b/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs
index e9cde7d70f2..995764d303e 100644
--- a/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs
+++ b/src/NHibernate/Dialect/Function/AnsiTrimEmulationFunction.cs
@@ -1,7 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
-
+using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
@@ -17,7 +17,7 @@ namespace NHibernate.Dialect.Function
/// functionality.
///
[Serializable]
- public class AnsiTrimEmulationFunction : ISQLFunction, IFunctionGrammar
+ public class AnsiTrimEmulationFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
{
private static readonly ISQLFunction LeadingSpaceTrim = new SQLFunctionTemplate(NHibernateUtil.String, "ltrim( ?1 )");
private static readonly ISQLFunction TrailingSpaceTrim = new SQLFunctionTemplate(NHibernateUtil.String, "rtrim( ?1 )");
@@ -76,11 +76,30 @@ public AnsiTrimEmulationFunction(string replaceFunction)
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.String;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => null;
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs b/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs
index 47f8903cb48..c780479747f 100644
--- a/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs
+++ b/src/NHibernate/Dialect/Function/AvgQueryFunctionInfo.cs
@@ -1,6 +1,6 @@
using System;
+using System.Collections.Generic;
using NHibernate.Engine;
-using NHibernate.SqlTypes;
using NHibernate.Type;
namespace NHibernate.Dialect.Function
@@ -10,27 +10,22 @@ class AvgQueryFunctionInfo : ClassicAggregateFunction
{
public AvgQueryFunctionInfo() : base("avg", false) { }
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public override IType ReturnType(IType columnType, IMapping mapping)
{
- if (columnType == null)
- {
- throw new ArgumentNullException("columnType");
- }
- SqlType[] sqlTypes;
- try
- {
- sqlTypes = columnType.SqlTypes(mapping);
- }
- catch (MappingException me)
- {
- throw new QueryException(me);
- }
+ return GetReturnType(new[] {columnType}, mapping, true);
+ }
- if (sqlTypes.Length != 1)
+ ///
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ if (!TryGetArgumentType(argumentTypes, mapping, throwOnError, out _, out _))
{
- throw new QueryException("multi-column type can not be in avg()");
+ return null;
}
+
return NHibernateUtil.Double;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs b/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs
index 5ebfbbe4c59..db6ba0c5365 100644
--- a/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs
+++ b/src/NHibernate/Dialect/Function/BitwiseFunctionOperation.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
@@ -27,8 +29,9 @@ namespace NHibernate.Dialect.Function
/// Treats bitwise operations as SQL function calls.
///
[Serializable]
- public class BitwiseFunctionOperation : ISQLFunction
+ public class BitwiseFunctionOperation : ISQLFunction, ISQLFunctionExtended
{
+ // TODO 6.0: convert FunctionName to read-only auto-property
private readonly string _functionName;
///
@@ -45,11 +48,30 @@ public BitwiseFunctionOperation(string functionName)
#region ISQLFunction Members
///
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int64;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => _functionName;
+
///
public bool HasArguments => true;
diff --git a/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs b/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs
index 01e9e462d97..3bffb4b1d59 100644
--- a/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs
+++ b/src/NHibernate/Dialect/Function/BitwiseNativeOperation.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
@@ -32,7 +34,7 @@ namespace NHibernate.Dialect.Function
/// Treats bitwise operations as native operations.
///
[Serializable]
- public class BitwiseNativeOperation : ISQLFunction
+ public class BitwiseNativeOperation : ISQLFunction, ISQLFunctionExtended
{
private readonly string _sqlOpToken;
private readonly bool _isUnary;
@@ -65,11 +67,30 @@ public BitwiseNativeOperation(string sqlOpToken, bool isUnary)
#region ISQLFunction Members
///
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int64;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => null;
+
///
public bool HasArguments => true;
diff --git a/src/NHibernate/Dialect/Function/CastFunction.cs b/src/NHibernate/Dialect/Function/CastFunction.cs
index 8580da1997f..2e02ea9dad1 100644
--- a/src/NHibernate/Dialect/Function/CastFunction.cs
+++ b/src/NHibernate/Dialect/Function/CastFunction.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using System.Xml;
using NHibernate.Engine;
using NHibernate.SqlCommand;
@@ -12,10 +14,12 @@ namespace NHibernate.Dialect.Function
/// ANSI-SQL style cast(foo as type) where the type is a NHibernate type
///
[Serializable]
- public class CastFunction : ISQLFunction, IFunctionGrammar
+ public class CastFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended
{
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
//note there is a weird implementation in the client side
@@ -23,6 +27,23 @@ public IType ReturnType(IType columnType, IMapping mapping)
return columnType;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "cast";
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/CharIndexFunction.cs b/src/NHibernate/Dialect/Function/CharIndexFunction.cs
index 1cdfafbff8c..7cc9f7bf543 100644
--- a/src/NHibernate/Dialect/Function/CharIndexFunction.cs
+++ b/src/NHibernate/Dialect/Function/CharIndexFunction.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using System.Text;
using NHibernate.Engine;
using NHibernate.SqlCommand;
@@ -11,7 +13,7 @@ namespace NHibernate.Dialect.Function
/// Emulation of locate() on Sybase
///
[Serializable]
- public class CharIndexFunction : ISQLFunction
+ public class CharIndexFunction : ISQLFunction, ISQLFunctionExtended
{
public CharIndexFunction()
{
@@ -19,11 +21,30 @@ public CharIndexFunction()
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int32;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "charindex";
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
index e0b78f1e1e2..cf0d4cf219c 100644
--- a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
+++ b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs
@@ -4,6 +4,7 @@
using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
+using NHibernate.SqlTypes;
using NHibernate.Type;
namespace NHibernate.Dialect.Function
@@ -40,15 +41,25 @@ public ClassicAggregateFunction(string name, bool acceptAsterisk, IType typeValu
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public virtual IType ReturnType(IType columnType, IMapping mapping)
{
return returnType ?? columnType;
}
///
- public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
{
+#pragma warning disable 618
return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
}
///
@@ -100,6 +111,63 @@ public SqlString Render(IList args, ISessionFactoryImplementor factory)
#endregion
+ protected bool TryGetArgumentType(
+ IEnumerable argumentTypes,
+ IMapping mapping,
+ bool throwOnError,
+ out IType argumentType,
+ out SqlType sqlType)
+ {
+ sqlType = null;
+ argumentType = null;
+ if (argumentTypes.Count() != 1)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException($"Invalid number of arguments for {name}()");
+ }
+
+ return false;
+ }
+
+ argumentType = argumentTypes.First();
+ if (argumentType == null)
+ {
+ // The argument is a parameter (e.g. select avg(:p1) from OrderLine). In that case, if the datatype is needed
+ // a QueryException will be thrown in SelectClause class, otherwise the query will be executed
+ // (e.g. select case when avg(:p1) > 0 then 1 else 0 end from OrderLine).
+ return false;
+ }
+
+ SqlType[] sqlTypes;
+ try
+ {
+ sqlTypes = argumentType.SqlTypes(mapping);
+ }
+ catch (MappingException me)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException(me);
+ }
+
+ return false;
+ }
+
+ if (sqlTypes.Length != 1)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException($"Multi-column type can not be in {name}()");
+ }
+
+ return false;
+ }
+
+ sqlType = sqlTypes[0];
+ return true;
+ }
+
public override string ToString()
{
return name;
diff --git a/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs b/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs
index ec802a85096..737ed7f2ae4 100644
--- a/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs
+++ b/src/NHibernate/Dialect/Function/ClassicAvgFunction.cs
@@ -1,7 +1,7 @@
using System;
+using System.Collections.Generic;
using System.Data;
using NHibernate.Engine;
-using NHibernate.SqlTypes;
using NHibernate.Type;
namespace NHibernate.Dialect.Function
@@ -16,36 +16,28 @@ public ClassicAvgFunction() : base("avg", false)
{
}
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public override IType ReturnType(IType columnType, IMapping mapping)
{
- if (columnType == null)
- {
- throw new ArgumentNullException("columnType");
- }
- SqlType[] sqlTypes;
- try
- {
- sqlTypes = columnType.SqlTypes(mapping);
- }
- catch (MappingException me)
- {
- throw new QueryException(me);
- }
+ return GetReturnType(new[] {columnType}, mapping, true);
+ }
- if (sqlTypes.Length != 1)
+ ///
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ if (!TryGetArgumentType(argumentTypes, mapping, throwOnError, out var argumentType, out var sqlType))
{
- throw new QueryException("multi-column type can not be in avg()");
+ return null;
}
- SqlType sqlType = sqlTypes[0];
-
if (sqlType.DbType == DbType.Int16 || sqlType.DbType == DbType.Int32 || sqlType.DbType == DbType.Int64)
{
return NHibernateUtil.Single;
}
else
{
- return columnType;
+ return argumentType;
}
}
}
diff --git a/src/NHibernate/Dialect/Function/ClassicCountFunction.cs b/src/NHibernate/Dialect/Function/ClassicCountFunction.cs
index 821146519e9..6cf138866df 100644
--- a/src/NHibernate/Dialect/Function/ClassicCountFunction.cs
+++ b/src/NHibernate/Dialect/Function/ClassicCountFunction.cs
@@ -1,4 +1,6 @@
using System;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.Type;
@@ -14,9 +16,19 @@ public ClassicCountFunction() : base("count", true)
{
}
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public override IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int32;
}
+
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ // 6.0 TODO: return NHibernateUtil.Int32;
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
}
-}
\ No newline at end of file
+}
diff --git a/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs b/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs
index b201a7522ea..6f2d873be28 100644
--- a/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs
+++ b/src/NHibernate/Dialect/Function/CountQueryFunctionInfo.cs
@@ -1,4 +1,6 @@
using System;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.Type;
@@ -9,9 +11,19 @@ class CountQueryFunctionInfo : ClassicAggregateFunction
{
public CountQueryFunctionInfo() : base("count", true) { }
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public override IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int64;
}
+
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ // 6.0 TODO: return NHibernateUtil.Int64;
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
}
-}
\ No newline at end of file
+}
diff --git a/src/NHibernate/Dialect/Function/ISQLFunction.cs b/src/NHibernate/Dialect/Function/ISQLFunction.cs
index 5302625a01e..8ced743a4b3 100644
--- a/src/NHibernate/Dialect/Function/ISQLFunction.cs
+++ b/src/NHibernate/Dialect/Function/ISQLFunction.cs
@@ -1,3 +1,4 @@
+using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
@@ -23,6 +24,8 @@ public interface ISQLFunction
/// The type of the first argument
///
///
+ // Since v5.3
+ [Obsolete("Use GetReturnType extension method instead.")]
IType ReturnType(IType columnType, IMapping mapping);
///
@@ -45,7 +48,7 @@ public interface ISQLFunction
}
// 6.0 TODO: Remove
- internal static class SQLFunctionExtensions
+ public static class SQLFunctionExtensions
{
///
/// Get the type that will be effectively returned by the underlying database.
@@ -68,7 +71,9 @@ public static IType GetEffectiveReturnType(
{
try
{
+#pragma warning disable 618
return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
}
catch (QueryException)
{
@@ -83,5 +88,42 @@ public static IType GetEffectiveReturnType(
return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError);
}
+
+ ///
+ /// Get the function general return type, ignoring underlying database specifics.
+ ///
+ /// The sql function.
+ /// The types of arguments.
+ /// The mapping for retrieving the argument sql types.
+ /// Whether to throw when the number of arguments is invalid or they are not supported.
+ /// The type returned by the underlying database or when the number of arguments
+ /// is invalid or they are not supported.
+ public static IType GetReturnType(
+ this ISQLFunction sqlFunction,
+ IEnumerable argumentTypes,
+ IMapping mapping,
+ bool throwOnError)
+ {
+ if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
+ {
+ try
+ {
+#pragma warning disable 618
+ return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+ catch (QueryException)
+ {
+ if (throwOnError)
+ {
+ throw;
+ }
+
+ return null;
+ }
+ }
+
+ return extendedSqlFunction.GetReturnType(argumentTypes, mapping, throwOnError);
+ }
}
}
diff --git a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
index e2db4747198..8534f347c80 100644
--- a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
+++ b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs
@@ -12,6 +12,16 @@ internal interface ISQLFunctionExtended : ISQLFunction
///
string FunctionName { get; }
+ ///
+ /// Get the function general return type, ignoring underlying database specifics.
+ ///
+ /// The types of arguments.
+ /// The mapping for retrieving the argument sql types.
+ /// Whether to throw when the number of arguments is invalid or they are not supported.
+ /// The type returned by the underlying database or when the number of arguments
+ /// is invalid or they are not supported.
+ IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError);
+
///
/// Get the type that will be effectively returned by the underlying database.
///
diff --git a/src/NHibernate/Dialect/Function/IifSQLFunction.cs b/src/NHibernate/Dialect/Function/IifSQLFunction.cs
new file mode 100644
index 00000000000..3aee491dc3f
--- /dev/null
+++ b/src/NHibernate/Dialect/Function/IifSQLFunction.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using NHibernate.Engine;
+using NHibernate.Type;
+
+namespace NHibernate.Dialect.Function
+{
+ [Serializable]
+ internal class IifSQLFunction : SQLFunctionTemplate
+ {
+ public IifSQLFunction() : base(null, "case when ?1 then ?2 else ?3 end")
+ {
+ }
+
+ ///
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ var args = argumentTypes.ToList();
+ if (args.Count != 3)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException($"Invalid number of arguments for iif()");
+ }
+
+ return null;
+ }
+
+ return args[1] ?? args[2];
+ }
+ }
+}
diff --git a/src/NHibernate/Dialect/Function/IifSafeSQLFunction.cs b/src/NHibernate/Dialect/Function/IifSafeSQLFunction.cs
new file mode 100644
index 00000000000..7828d68b9aa
--- /dev/null
+++ b/src/NHibernate/Dialect/Function/IifSafeSQLFunction.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using NHibernate.Engine;
+using NHibernate.Type;
+
+namespace NHibernate.Dialect.Function
+{
+ [Serializable]
+ internal class IifSafeSQLFunction : StandardSafeSQLFunction
+ {
+ public IifSafeSQLFunction() : base("iif", 3)
+ {
+ }
+
+ ///
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ var args = argumentTypes.ToList();
+ if (args.Count != 3)
+ {
+ return null; // Not enough information
+ }
+
+ return args[1] ?? args[2];
+ }
+ }
+}
diff --git a/src/NHibernate/Dialect/Function/ModulusFunction.cs b/src/NHibernate/Dialect/Function/ModulusFunction.cs
new file mode 100644
index 00000000000..46b287b4ab4
--- /dev/null
+++ b/src/NHibernate/Dialect/Function/ModulusFunction.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using NHibernate.Engine;
+using NHibernate.Type;
+
+namespace NHibernate.Dialect.Function
+{
+ [Serializable]
+ internal class ModulusFunction : StandardSafeSQLFunction
+ {
+ private readonly ModulusFunctionTypeDetector _modulusFunctionTypeDetector;
+
+ public ModulusFunction(bool supportDecimals, bool supportFloatingNumbers)
+ : this(new ModulusFunctionTypeDetector(supportDecimals, supportFloatingNumbers))
+ {
+ }
+
+ public ModulusFunction(ModulusFunctionTypeDetector modulusFunction) : base("mod", NHibernateUtil.Int32, 2)
+ {
+ _modulusFunctionTypeDetector = modulusFunction;
+ }
+
+ ///
+ public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return _modulusFunctionTypeDetector.GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+ }
+}
diff --git a/src/NHibernate/Dialect/Function/ModulusFunctionTemplate.cs b/src/NHibernate/Dialect/Function/ModulusFunctionTemplate.cs
new file mode 100644
index 00000000000..881cc7816be
--- /dev/null
+++ b/src/NHibernate/Dialect/Function/ModulusFunctionTemplate.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using NHibernate.Engine;
+using NHibernate.Type;
+
+namespace NHibernate.Dialect.Function
+{
+ [Serializable]
+ internal class ModulusFunctionTemplate : SQLFunctionTemplate
+ {
+ private readonly ModulusFunctionTypeDetector _modulusFunctionTypeDetector;
+
+ public ModulusFunctionTemplate(bool supportDecimals) : this(new ModulusFunctionTypeDetector(supportDecimals))
+ {
+ }
+
+ public ModulusFunctionTemplate(ModulusFunctionTypeDetector modulusFunction) : base(NHibernateUtil.Int32, "((?1) % (?2))")
+ {
+ _modulusFunctionTypeDetector = modulusFunction;
+ }
+
+ ///
+ public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return _modulusFunctionTypeDetector.GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+ }
+}
diff --git a/src/NHibernate/Dialect/Function/ModulusFunctionTypeDetector.cs b/src/NHibernate/Dialect/Function/ModulusFunctionTypeDetector.cs
new file mode 100644
index 00000000000..eab42090b12
--- /dev/null
+++ b/src/NHibernate/Dialect/Function/ModulusFunctionTypeDetector.cs
@@ -0,0 +1,113 @@
+using System;
+using System.Collections.Generic;
+using System.Data;
+using NHibernate.Engine;
+using NHibernate.SqlTypes;
+using NHibernate.Type;
+
+namespace NHibernate.Dialect.Function
+{
+ [Serializable]
+ internal class ModulusFunctionTypeDetector
+ {
+ // The supported DbTypes with their priorities in order to detect which is the
+ // returned type when mixing them
+ private readonly Lazy>> _supportedDbTypesLazy;
+ private readonly bool _supportDecimals;
+ private readonly bool _supportFloatingNumbers;
+
+ public ModulusFunctionTypeDetector(bool supportDecimals, bool supportFloatingNumbers)
+ {
+ _supportDecimals = supportDecimals;
+ _supportFloatingNumbers = supportFloatingNumbers;
+ _supportedDbTypesLazy = new Lazy>>(GetSupportedTypes);
+ }
+
+ public ModulusFunctionTypeDetector(bool supportDecimals) : this(supportDecimals, false)
+ {
+ }
+
+ protected virtual Dictionary> GetSupportedTypes()
+ {
+ var types = new Dictionary>()
+ {
+ {DbType.Int16, new KeyValuePair(1, NHibernateUtil.Int16)},
+ {DbType.Int32, new KeyValuePair(2, NHibernateUtil.Int32)},
+ {DbType.Int64, new KeyValuePair(3, NHibernateUtil.Int64)},
+ };
+
+ if (_supportDecimals)
+ {
+ types.Add(DbType.Currency, new KeyValuePair(4, NHibernateUtil.Decimal));
+ types.Add(DbType.Decimal, new KeyValuePair(4, NHibernateUtil.Decimal));
+ }
+
+ if (_supportFloatingNumbers)
+ {
+ types.Add(DbType.Single, new KeyValuePair(5, NHibernateUtil.Single));
+ types.Add(DbType.Double, new KeyValuePair(6, NHibernateUtil.Double));
+ }
+
+ return types;
+ }
+
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ KeyValuePair currentReturnType = default;
+ int totalArguments = 0;
+ foreach (var argumentType in argumentTypes)
+ {
+ if (argumentType == null)
+ {
+ return null;
+ }
+
+ SqlType[] sqlTypes;
+ try
+ {
+ sqlTypes = argumentType.SqlTypes(mapping);
+ }
+ catch (MappingException me)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException(me);
+ }
+
+ return null;
+ }
+
+ if (sqlTypes.Length != 1)
+ {
+ return ThrowOrReturnDefault("Multi-column type can not be in mod()", throwOnError);
+ }
+
+ if (!_supportedDbTypesLazy.Value.TryGetValue(sqlTypes[0].DbType, out var returnType))
+ {
+ return ThrowOrReturnDefault($"DbType {sqlTypes[0].DbType} is not supported for mod()", throwOnError);
+ }
+
+ if (returnType.Key > currentReturnType.Key)
+ {
+ currentReturnType = returnType;
+ }
+
+ totalArguments++;
+ }
+
+ return totalArguments == 2
+ ? currentReturnType.Value
+ : ThrowOrReturnDefault("Invalid number of arguments for mod()", throwOnError);
+ }
+
+ private IType ThrowOrReturnDefault(string error, bool throwOnError)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException(error);
+ }
+
+ return null;
+ }
+ }
+}
diff --git a/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs b/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs
index e8e700a9773..1cca8bd6cac 100644
--- a/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs
+++ b/src/NHibernate/Dialect/Function/NoArgSQLFunction.cs
@@ -3,6 +3,8 @@
using NHibernate.SqlCommand;
using NHibernate.Type;
using System;
+using System.Collections.Generic;
+using System.Linq;
namespace NHibernate.Dialect.Function
{
@@ -10,7 +12,7 @@ namespace NHibernate.Dialect.Function
/// Summary description for NoArgSQLFunction.
///
[Serializable]
- public class NoArgSQLFunction : ISQLFunction
+ public class NoArgSQLFunction : ISQLFunction, ISQLFunctionExtended
{
public NoArgSQLFunction(string name, IType returnType)
: this(name, returnType, true)
@@ -19,22 +21,47 @@ public NoArgSQLFunction(string name, IType returnType)
public NoArgSQLFunction(string name, IType returnType, bool hasParenthesesIfNoArguments)
{
+#pragma warning disable 618
Name = name;
+#pragma warning restore 618
FunctionReturnType = returnType;
HasParenthesesIfNoArguments = hasParenthesesIfNoArguments;
}
public IType FunctionReturnType { get; protected set; }
+ // Since v5.3
+ [Obsolete("Use FunctionName property instead.")]
public string Name { get; protected set; }
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return FunctionReturnType;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+#pragma warning disable 618
+ public string FunctionName => Name;
+#pragma warning restore 618
+
public bool HasArguments
{
get { return false; }
@@ -46,15 +73,15 @@ public virtual SqlString Render(IList args, ISessionFactoryImplementor factory)
{
if (args.Count > 0)
{
- throw new QueryException("function takes no arguments: " + Name);
+ throw new QueryException("function takes no arguments: " + FunctionName);
}
if (HasParenthesesIfNoArguments)
{
- return new SqlString(Name + "()");
+ return new SqlString(FunctionName + "()");
}
- return new SqlString(Name);
+ return new SqlString(FunctionName);
}
#endregion
diff --git a/src/NHibernate/Dialect/Function/NvlFunction.cs b/src/NHibernate/Dialect/Function/NvlFunction.cs
index aa17cf0fbd1..9ffba90f04f 100644
--- a/src/NHibernate/Dialect/Function/NvlFunction.cs
+++ b/src/NHibernate/Dialect/Function/NvlFunction.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.Type;
@@ -10,7 +12,7 @@ namespace NHibernate.Dialect.Function
/// Emulation of coalesce() on Oracle, using multiple nvl() calls
///
[Serializable]
- public class NvlFunction : ISQLFunction
+ public class NvlFunction : ISQLFunction, ISQLFunctionExtended
{
public NvlFunction()
{
@@ -18,11 +20,30 @@ public NvlFunction()
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return columnType;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "nvl";
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs b/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs
index 9a89bda39c4..717503d6b12 100644
--- a/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs
+++ b/src/NHibernate/Dialect/Function/PositionSubstringFunction.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Antlr.Runtime;
using NHibernate.Engine;
@@ -12,7 +14,7 @@ namespace NHibernate.Dialect.Function
/// Emulation of locate() on PostgreSQL
///
[Serializable]
- public class PositionSubstringFunction : ISQLFunction
+ public class PositionSubstringFunction : ISQLFunction, ISQLFunctionExtended
{
public PositionSubstringFunction()
{
@@ -20,11 +22,30 @@ public PositionSubstringFunction()
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int32;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "position";
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs b/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs
index 976a5f0e127..4b5a7e0fdc7 100644
--- a/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs
+++ b/src/NHibernate/Dialect/Function/SQLFunctionTemplate.cs
@@ -6,6 +6,8 @@
using NHibernate.SqlCommand;
using NHibernate.Type;
using System;
+using System.Collections.Generic;
+using System.Linq;
namespace NHibernate.Dialect.Function
{
@@ -18,7 +20,7 @@ namespace NHibernate.Dialect.Function
/// parameters with '?' followed by parameter's index (first index is 1).
///
[Serializable]
- public class SQLFunctionTemplate : ISQLFunction
+ public class SQLFunctionTemplate : ISQLFunction, ISQLFunctionExtended
{
private const int InvalidArgumentIndex = -1;
private static readonly Regex SplitRegex = new Regex("(\\?[0-9]+)");
@@ -80,11 +82,30 @@ private void InitFromTemplate()
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return (returnType == null) ? columnType : returnType;
}
+ ///
+ public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public virtual string FunctionName => null;
+
public bool HasArguments
{
get { return hasArguments; }
diff --git a/src/NHibernate/Dialect/Function/StandardSQLFunction.cs b/src/NHibernate/Dialect/Function/StandardSQLFunction.cs
index df1adc53f09..279e9089919 100644
--- a/src/NHibernate/Dialect/Function/StandardSQLFunction.cs
+++ b/src/NHibernate/Dialect/Function/StandardSQLFunction.cs
@@ -4,6 +4,8 @@
using NHibernate.SqlCommand;
using NHibernate.Type;
using System;
+using System.Collections.Generic;
+using System.Linq;
namespace NHibernate.Dialect.Function
{
@@ -16,7 +18,7 @@ namespace NHibernate.Dialect.Function
/// for processing of the associated function.
///
[Serializable]
- public class StandardSQLFunction : ISQLFunction
+ public class StandardSQLFunction : ISQLFunction, ISQLFunctionExtended
{
private IType returnType = null;
protected readonly string name;
@@ -43,11 +45,30 @@ public StandardSQLFunction(string name, IType typeValue)
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public virtual IType ReturnType(IType columnType, IMapping mapping)
{
return returnType ?? columnType;
}
+ ///
+ public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => name;
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs b/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs
index 3ee6c7d6f2e..6c0af160841 100644
--- a/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs
+++ b/src/NHibernate/Dialect/Function/SumQueryFunctionInfo.cs
@@ -1,7 +1,7 @@
using System;
+using System.Collections.Generic;
using System.Data;
using NHibernate.Engine;
-using NHibernate.SqlTypes;
using NHibernate.Type;
namespace NHibernate.Dialect.Function
@@ -12,29 +12,21 @@ class SumQueryFunctionInfo : ClassicAggregateFunction
public SumQueryFunctionInfo() : base("sum", false) { }
//H3.2 behavior
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public override IType ReturnType(IType columnType, IMapping mapping)
{
- if (columnType == null)
- {
- throw new ArgumentNullException("columnType");
- }
- SqlType[] sqlTypes;
- try
- {
- sqlTypes = columnType.SqlTypes(mapping);
- }
- catch (MappingException me)
- {
- throw new QueryException(me);
- }
+ return GetReturnType(new[] { columnType }, mapping, true);
+ }
- if (sqlTypes.Length != 1)
+ ///
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ if (!TryGetArgumentType(argumentTypes, mapping, throwOnError, out var argumentType, out var sqlType))
{
- throw new QueryException("multi-column type can not be in sum()");
+ return null;
}
- SqlType sqlType = sqlTypes[0];
-
// TODO: (H3.2 for nullable types) First allow the actual type to control the return value. (the actual underlying sqltype could actually be different)
// finally use the sqltype if == on Hibernate types did not find a match.
@@ -57,8 +49,8 @@ public override IType ReturnType(IType columnType, IMapping mapping)
return NHibernateUtil.UInt64;
default:
- return columnType;
+ return argumentType;
}
}
}
-}
\ No newline at end of file
+}
diff --git a/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs b/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs
index 457a1b86d8e..64a0aa213da 100644
--- a/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs
+++ b/src/NHibernate/Dialect/Function/VarArgsSQLFunction.cs
@@ -1,5 +1,7 @@
using System;
using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
using System.Text;
using NHibernate.Engine;
using NHibernate.SqlCommand;
@@ -12,7 +14,7 @@ namespace NHibernate.Dialect.Function
/// with an unlimited number of arguments.
///
[Serializable]
- public class VarArgsSQLFunction : ISQLFunction
+ public class VarArgsSQLFunction : ISQLFunction, ISQLFunctionExtended
{
private readonly string begin;
private readonly string sep;
@@ -34,11 +36,30 @@ public VarArgsSQLFunction(IType type, string begin, string sep, string end)
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public virtual IType ReturnType(IType columnType, IMapping mapping)
{
return (returnType == null) ? columnType : returnType;
}
+ ///
+ public virtual IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public virtual string FunctionName => null;
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/HanaDialectBase.cs b/src/NHibernate/Dialect/HanaDialectBase.cs
index 4a8d70db2e8..99203e74605 100644
--- a/src/NHibernate/Dialect/HanaDialectBase.cs
+++ b/src/NHibernate/Dialect/HanaDialectBase.cs
@@ -1,7 +1,9 @@
using System;
using System.Collections;
+using System.Collections.Generic;
using System.Data;
using System.Data.Common;
+using System.Linq;
using NHibernate.Dialect.Function;
using NHibernate.Dialect.Schema;
using NHibernate.Engine;
@@ -17,7 +19,7 @@ namespace NHibernate.Dialect
public abstract class HanaDialectBase : Dialect
{
[Serializable]
- private class TypeConvertingVarArgsSQLFunction : ISQLFunction
+ private class TypeConvertingVarArgsSQLFunction : ISQLFunction, ISQLFunctionExtended
{
private readonly string _begin;
private readonly string _sep;
@@ -33,12 +35,31 @@ public TypeConvertingVarArgsSQLFunction(string begin, string sep, string end)
#region ISQLFunction Members
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
_type = columnType.SqlTypes(mapping)[0];
return columnType;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public virtual string FunctionName => "cast";
+
public bool HasArguments => true;
public bool HasParenthesesIfNoArguments => true;
@@ -392,7 +413,7 @@ protected virtual void RegisterNHibernateFunctions()
RegisterFunction("ceiling", new StandardSQLFunction("ceil"));
RegisterFunction("chr", new StandardSQLFunction("char", NHibernateUtil.AnsiChar));
RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.Date, "to_date(?1)"));
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
RegisterFunction("sysdate", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, false));
RegisterFunction("truncate", new SQLFunctionTemplateWithRequiredParameters(null, "floor(?1 * power(10, ?2)) / power(10, ?2)", new object[] { null, "0" }));
RegisterFunction("new_uuid", new NoArgSQLFunction("sysuuid", NHibernateUtil.Guid, false));
@@ -498,7 +519,7 @@ protected virtual void RegisterHANAFunctions()
RegisterFunction("map", new VarArgsSQLFunction("map(", ",", ")"));
RegisterFunction("mimetype", new StandardSQLFunction("mimetype", NHibernateUtil.String));
RegisterFunction("minute", new StandardSQLFunction("minute", NHibernateUtil.Int32));
- RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32));
+ RegisterFunction("mod", new ModulusFunction(false, false));
RegisterFunction("month", new StandardSQLFunction("month", NHibernateUtil.Int32));
RegisterFunction("monthname", new StandardSQLFunction("monthname", NHibernateUtil.String));
RegisterFunction("months_between", new StandardSQLFunction("months_between", NHibernateUtil.Int32));
diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs
index 35c760e85b3..1aec544ed2d 100644
--- a/src/NHibernate/Dialect/MsSql2000Dialect.cs
+++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs
@@ -313,7 +313,7 @@ protected virtual void RegisterFunctions()
RegisterFunction("ln", new StandardSQLFunction("ln", NHibernateUtil.Double));
RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Double));
RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double));
- RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
+ RegisterFunction("mod", new ModulusFunctionTemplate(true));
RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double));
RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double));
// SQL Server rand returns the same value for each row, unless hacking it with a random seed per row
@@ -351,7 +351,7 @@ protected virtual void RegisterFunctions()
RegisterFunction("ltrim", new StandardSQLFunction("ltrim"));
RegisterFunction("trim", new AnsiTrimEmulationFunction());
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
RegisterFunction("replace", new StandardSafeSQLFunction("replace", NHibernateUtil.String, 3));
// Casting to CHAR (without specified length) truncates to 30 characters.
diff --git a/src/NHibernate/Dialect/MsSql2012Dialect.cs b/src/NHibernate/Dialect/MsSql2012Dialect.cs
index 8649d6861db..7ec84ef66f9 100644
--- a/src/NHibernate/Dialect/MsSql2012Dialect.cs
+++ b/src/NHibernate/Dialect/MsSql2012Dialect.cs
@@ -53,7 +53,7 @@ public override string QuerySequencesString
protected override void RegisterFunctions()
{
base.RegisterFunctions();
- RegisterFunction("iif", new StandardSafeSQLFunction("iif", 3));
+ RegisterFunction("iif", new IifSafeSQLFunction());
}
public override SqlString GetLimitString(SqlString querySqlString, SqlString offset, SqlString limit)
diff --git a/src/NHibernate/Dialect/MsSqlCeDialect.cs b/src/NHibernate/Dialect/MsSqlCeDialect.cs
index fd0baed402b..86edd426e6e 100644
--- a/src/NHibernate/Dialect/MsSqlCeDialect.cs
+++ b/src/NHibernate/Dialect/MsSqlCeDialect.cs
@@ -191,10 +191,11 @@ protected virtual void RegisterFunctions()
RegisterFunction("lower", new StandardSQLFunction("lower"));
RegisterFunction("trim", new AnsiTrimEmulationFunction());
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "+", ")"));
- RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
+ // Modulo is not supported on real, float, money, and numeric data types
+ RegisterFunction("mod", new ModulusFunctionTemplate(false));
RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"}));
RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0", "1"}));
diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs
index a6caa6d236a..b43c2c632c3 100644
--- a/src/NHibernate/Dialect/MySQLDialect.cs
+++ b/src/NHibernate/Dialect/MySQLDialect.cs
@@ -1,11 +1,15 @@
using System;
+using System.Collections.Generic;
using System.Data;
using System.Data.Common;
+using System.Linq;
using System.Text;
using NHibernate.Dialect.Function;
using NHibernate.Dialect.Schema;
+using NHibernate.Engine;
using NHibernate.SqlCommand;
using NHibernate.SqlTypes;
+using NHibernate.Type;
using NHibernate.Util;
using Environment=NHibernate.Cfg.Environment;
@@ -242,7 +246,8 @@ protected virtual void RegisterKeywords()
protected virtual void RegisterFunctions()
{
- RegisterFunction("iif", new StandardSQLFunction("if"));
+ RegisterFunction("iif", new IfSQLFunction());
+ RegisterFunction("mod", new ModulusFunction(true, true));
RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32));
@@ -549,5 +554,30 @@ public override long TimestampResolutionInTicks
public override bool SupportsDistributedTransactions => false;
#endregion
+
+ [Serializable]
+ internal class IfSQLFunction : StandardSQLFunction
+ {
+ public IfSQLFunction() : base("if")
+ {
+ }
+
+ ///
+ public override IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ var args = argumentTypes.ToList();
+ if (args.Count != 3)
+ {
+ if (throwOnError)
+ {
+ throw new QueryException($"Invalid number of arguments for iif()");
+ }
+
+ return null;
+ }
+
+ return args[1] ?? args[2];
+ }
+ }
}
}
diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs
index 749c1f0d056..c484b7f35a2 100644
--- a/src/NHibernate/Dialect/Oracle8iDialect.cs
+++ b/src/NHibernate/Dialect/Oracle8iDialect.cs
@@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
+using System.Linq;
using NHibernate.Dialect.Function;
using NHibernate.Dialect.Schema;
using NHibernate.Engine;
@@ -291,7 +292,7 @@ protected virtual void RegisterFunctions()
// Multi-param numeric dialect functions...
RegisterFunction("atan2", new StandardSQLFunction("atan2", NHibernateUtil.Double));
RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Int32));
- RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32));
+ RegisterFunction("mod", new ModulusFunction(true, true));
RegisterFunction("nvl", new StandardSQLFunction("nvl"));
RegisterFunction("nvl2", new StandardSQLFunction("nvl2"));
RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double));
@@ -304,7 +305,7 @@ protected virtual void RegisterFunctions()
RegisterFunction("str", new StandardSQLFunction("to_char", NHibernateUtil.String));
RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "substr(rawtohex(?1), 7, 2) || substr(rawtohex(?1), 5, 2) || substr(rawtohex(?1), 3, 2) || substr(rawtohex(?1), 1, 2) || '-' || substr(rawtohex(?1), 11, 2) || substr(rawtohex(?1), 9, 2) || '-' || substr(rawtohex(?1), 15, 2) || substr(rawtohex(?1), 13, 2) || '-' || substr(rawtohex(?1), 17, 4) || '-' || substr(rawtohex(?1), 21) "));
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
RegisterFunction("band", new Function.BitwiseFunctionOperation("bitand"));
RegisterFunction("bor", new SQLFunctionTemplate(null, "?1 + ?2 - BITAND(?1, ?2)"));
@@ -577,11 +578,11 @@ public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.LocalDateTi
public override SqlString Render(IList args, ISessionFactoryImplementor factory)
{
- return new SqlString(Name);
+ return new SqlString(FunctionName);
}
}
[Serializable]
- private class LocateFunction : ISQLFunction
+ private class LocateFunction : ISQLFunction, ISQLFunctionExtended
{
private static readonly ISQLFunction LocateWith2Params = new SQLFunctionTemplate(NHibernateUtil.Int32,
"instr(?2, ?1)");
@@ -591,11 +592,30 @@ private class LocateFunction : ISQLFunction
#region Implementation of ISQLFunction
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping)
{
return NHibernateUtil.Int32;
}
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => "instr";
+
public bool HasArguments
{
get { return true; }
diff --git a/src/NHibernate/Dialect/OracleLiteDialect.cs b/src/NHibernate/Dialect/OracleLiteDialect.cs
index cf5e33914a5..eb5af510652 100644
--- a/src/NHibernate/Dialect/OracleLiteDialect.cs
+++ b/src/NHibernate/Dialect/OracleLiteDialect.cs
@@ -101,7 +101,7 @@ public OracleLiteDialect()
RegisterFunction("translate", new StandardSQLFunction("translate", NHibernateUtil.String));
// Multi-param numeric dialect functions...
- RegisterFunction("mod", new StandardSQLFunction("mod", NHibernateUtil.Int32));
+ RegisterFunction("mod", new ModulusFunction(true, false));
RegisterFunction("nvl", new StandardSQLFunction("nvl"));
// Multi-param date dialect functions...
diff --git a/src/NHibernate/Dialect/PostgreSQLDialect.cs b/src/NHibernate/Dialect/PostgreSQLDialect.cs
index baa9334d788..44bbdaff493 100644
--- a/src/NHibernate/Dialect/PostgreSQLDialect.cs
+++ b/src/NHibernate/Dialect/PostgreSQLDialect.cs
@@ -1,7 +1,9 @@
using System;
using System.Collections;
+using System.Collections.Generic;
using System.Data;
using System.Data.Common;
+using System.Linq;
using NHibernate.Dialect.Function;
using NHibernate.Dialect.Schema;
using NHibernate.Engine;
@@ -63,10 +65,10 @@ public PostgreSQLDialect()
RegisterFunction("current_timestamp", new NoArgSQLFunction("now", NHibernateUtil.LocalDateTime, true));
RegisterFunction("str", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as varchar)"));
RegisterFunction("locate", new PositionSubstringFunction());
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
RegisterFunction("replace", new StandardSQLFunction("replace", NHibernateUtil.String));
RegisterFunction("left", new SQLFunctionTemplate(NHibernateUtil.String, "substr(?1,1,?2)"));
- RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
+ RegisterFunction("mod", new ModulusFunctionTemplate(true));
RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32));
RegisterFunction("round", new RoundFunction(false));
@@ -347,7 +349,7 @@ public override string CurrentTimestampSelectString
#endregion
[Serializable]
- private class RoundFunction : ISQLFunction
+ private class RoundFunction : ISQLFunction, ISQLFunctionExtended
{
private static readonly ISQLFunction Round = new StandardSQLFunction("round");
private static readonly ISQLFunction Truncate = new StandardSQLFunction("trunc");
@@ -361,7 +363,7 @@ private class RoundFunction : ISQLFunction
private readonly ISQLFunction _singleParamFunction;
private readonly ISQLFunction _twoParamFunction;
- private readonly string _name;
+ private readonly string _name; // TODO 6.0: convert FunctionName to read-only auto property
public RoundFunction(bool truncate)
{
@@ -379,8 +381,27 @@ public RoundFunction(bool truncate)
}
}
+ // Since v5.3
+ [Obsolete("Use GetReturnType method instead.")]
public IType ReturnType(IType columnType, IMapping mapping) => columnType;
+ ///
+ public IType GetReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+#pragma warning disable 618
+ return ReturnType(argumentTypes.FirstOrDefault(), mapping);
+#pragma warning restore 618
+ }
+
+ ///
+ public IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError)
+ {
+ return GetReturnType(argumentTypes, mapping, throwOnError);
+ }
+
+ ///
+ public string FunctionName => _name;
+
public bool HasArguments => true;
public bool HasParenthesesIfNoArguments => true;
diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs
index 5eb1f51dd87..fa71342185f 100644
--- a/src/NHibernate/Dialect/SQLiteDialect.cs
+++ b/src/NHibernate/Dialect/SQLiteDialect.cs
@@ -110,9 +110,9 @@ protected virtual void RegisterFunctions()
RegisterFunction("replace", new StandardSafeSQLFunction("replace", NHibernateUtil.String, 3));
RegisterFunction("chr", new StandardSQLFunction("char", NHibernateUtil.Character));
- RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))"));
+ RegisterFunction("mod", new ModulusFunctionTemplate(false));
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
RegisterFunction("round", new StandardSQLFunction("round"));
diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs
index a5233395d36..e45b1f7700a 100644
--- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs
+++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs
@@ -93,7 +93,7 @@ public SybaseASE15Dialect()
RegisterFunction("lower", new StandardSQLFunction("lower"));
RegisterFunction("ltrim", new StandardSQLFunction("ltrim"));
RegisterFunction("minute", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(minute, ?1)"));
- RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "?1 % ?2"));
+ RegisterFunction("mod", new ModulusFunctionTemplate(false));
RegisterFunction("month", new StandardSQLFunction("month", NHibernateUtil.Int32));
RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double));
RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double));
diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs
index f5b61250049..290e626671f 100644
--- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs
+++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs
@@ -137,7 +137,7 @@ protected virtual void RegisterMathFunctions()
RegisterFunction("floor", new StandardSQLFunction("floor", NHibernateUtil.Double));
RegisterFunction("log", new StandardSQLFunction("log", NHibernateUtil.Double));
RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double));
- RegisterFunction("mod", new StandardSQLFunction("mod"));
+ RegisterFunction("mod", new ModulusFunction(false, false));
RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double, true));
RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double));
RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double));
@@ -355,7 +355,7 @@ protected virtual void RegisterMiscellaneousFunctions()
RegisterFunction("transactsql", new StandardSQLFunction("transactsql", NHibernateUtil.String));
RegisterFunction("varexists", new StandardSQLFunction("varexists", NHibernateUtil.Int32));
RegisterFunction("watcomsql", new StandardSQLFunction("watcomsql", NHibernateUtil.String));
- RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end"));
+ RegisterFunction("iif", new IifSQLFunction());
}
#region private static readonly string[] DialectKeywords = { ... }
diff --git a/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs b/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs
index b88f22ec0a6..33d6103ec86 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/SessionFactoryHelperExtensions.cs
@@ -1,4 +1,6 @@
using System;
+using System.Collections.Generic;
+using System.Linq;
using NHibernate.Dialect.Function;
using NHibernate.Engine;
using NHibernate.Hql.Ast.ANTLR.Tree;
@@ -9,6 +11,7 @@
using NHibernate.Type;
using NHibernate.Util;
using IASTNode=NHibernate.Hql.Ast.ANTLR.Tree.IASTNode;
+using IQueryable = NHibernate.Persister.Entity.IQueryable;
namespace NHibernate.Hql.Ast.ANTLR
{
@@ -67,6 +70,8 @@ private ISQLFunction RequireSQLFunction(string functionName)
/// The function name.
/// The first argument expression.
/// the function return type given the function name and the first argument expression node.
+ // Since v5.3
+ [Obsolete("Please use overload with a IEnumerable parameter instead.")]
public IType FindFunctionReturnType(String functionName, IASTNode first)
{
// locate the registered function by the given name
@@ -90,6 +95,44 @@ public IType FindFunctionReturnType(String functionName, IASTNode first)
return sqlFunction.ReturnType(argumentType, _sfi);
}
+ ///
+ /// Find the function return type given the function name and the arguments expression nodes.
+ ///
+ /// The function name.
+ /// The function arguments expression nodes.
+ /// The function return type given the function name and the arguments expression nodes.
+ public IType FindFunctionReturnType(string functionName, IEnumerable arguments)
+ {
+ var sqlFunction = RequireSQLFunction(functionName);
+ if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction))
+ {
+#pragma warning disable 618
+ return FindFunctionReturnType(functionName, arguments.FirstOrDefault());
+#pragma warning restore 618
+ }
+
+ var argumentTypes = new List();
+ if (sqlFunction is CastFunction)
+ {
+ argumentTypes.Add(TypeFactory.HeuristicType(arguments.First().NextSibling.Text));
+ }
+ else
+ {
+ foreach (var argument in arguments)
+ {
+ IType type = null;
+ if (argument is SqlNode sqlNode)
+ {
+ type = sqlNode.DataType;
+ }
+
+ argumentTypes.Add(type);
+ }
+ }
+
+ return extendedSqlFunction.GetReturnType(argumentTypes, _sfi, true);
+ }
+
///
/// Given a (potentially unqualified) class name, locate its imported qualified name.
///
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
index a3887a56c01..d330103082a 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Generic;
using Antlr.Runtime;
using NHibernate.Dialect.Function;
using NHibernate.Type;
@@ -38,7 +39,7 @@ public override IType DataType
get
{
// Get the function return value type, based on the type of the first argument.
- return SessionFactoryHelper.FindFunctionReturnType(Text, GetChild(0));
+ return SessionFactoryHelper.FindFunctionReturnType(Text, (IEnumerable) this);
}
set
{
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
index ef6d6272043..e381c921bbe 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs
@@ -1,4 +1,5 @@
-using Antlr.Runtime;
+using System.Linq;
+using Antlr.Runtime;
using NHibernate.Dialect.Function;
using NHibernate.Hql.Ast.ANTLR.Util;
using NHibernate.Type;
@@ -20,7 +21,7 @@ public override IType DataType
{
get
{
- return SessionFactoryHelper.FindFunctionReturnType(Text, null);
+ return SessionFactoryHelper.FindFunctionReturnType(Text, Enumerable.Empty());
}
set
{
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs
index 38bb20080bd..095f8b03545 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/IdentNode.cs
@@ -1,13 +1,14 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using Antlr.Runtime;
using NHibernate.Dialect.Function;
using NHibernate.Hql.Ast.ANTLR.Util;
using NHibernate.Persister.Collection;
-using NHibernate.Persister.Entity;
using NHibernate.SqlCommand;
using NHibernate.Type;
using NHibernate.Util;
+using IQueryable = NHibernate.Persister.Entity.IQueryable;
namespace NHibernate.Hql.Ast.ANTLR.Tree
{
@@ -38,7 +39,7 @@ public override IType DataType
return fe.DataType;
}
ISQLFunction sf = Walker.SessionFactoryHelper.FindSQLFunction(Text);
- return sf?.ReturnType(null, Walker.SessionFactoryHelper.Factory);
+ return sf?.GetReturnType(Enumerable.Empty(), Walker.SessionFactoryHelper.Factory, true);
}
set
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs
index f36ac08ff69..dc6c04042e7 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/MethodNode.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Generic;
using Antlr.Runtime;
using NHibernate.Dialect.Function;
@@ -194,14 +195,7 @@ private void DialectFunction(IASTNode exprList)
if (_function != null)
{
- IASTNode child = null;
-
- if (exprList != null)
- {
- child = _methodName == "iif" ? exprList.GetChild(1) : exprList.GetChild(0);
- }
-
- DataType = SessionFactoryHelper.FindFunctionReturnType(_methodName, child);
+ DataType = SessionFactoryHelper.FindFunctionReturnType(_methodName, (IEnumerable) exprList);
}
//TODO:
/*else {