diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2704/FixtureByCode.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2704/FixtureByCode.cs
new file mode 100644
index 00000000000..ca7c8a1a16d
--- /dev/null
+++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2704/FixtureByCode.cs
@@ -0,0 +1,82 @@
+//------------------------------------------------------------------------------
+//
+// This code was generated by AsyncGenerator.
+//
+// Changes to this file may cause incorrect behavior and will be lost if
+// the code is regenerated.
+//
+//------------------------------------------------------------------------------
+
+
+using System.Linq;
+using NHibernate.Cfg.MappingSchema;
+using NHibernate.Mapping.ByCode;
+using NUnit.Framework;
+using NHibernate.Linq;
+
+namespace NHibernate.Test.NHSpecificTest.GH2704
+{
+ using System.Threading.Tasks;
+ [TestFixture]
+ public class EnhancedUserTypeFixtureAsync : TestCaseMappingByCode
+ {
+ protected override HbmMapping GetMappings()
+ {
+ var mapper = new ModelMapper();
+
+ mapper.AddMapping();
+ return mapper.CompileMappingForAllExplicitlyAddedEntities();
+ }
+
+ protected override void OnSetUp()
+ {
+ using (var session = OpenSession())
+ using (var transaction = session.BeginTransaction())
+ {
+ session.Save(new Entity1() {Id = "id1", IsChiusa = true});
+ session.Save(new Entity1() {Id = "id2", IsChiusa = false});
+ transaction.Commit();
+ }
+ }
+
+ protected override void OnTearDown()
+ {
+ using (var session = OpenSession())
+ using (var transaction = session.BeginTransaction())
+ {
+ session.CreateQuery("delete from System.Object").ExecuteUpdate();
+ transaction.Commit();
+ }
+ }
+
+ [Test]
+ public async Task CompareWithConstantAsync()
+ {
+ var yes = true;
+ using (var s = OpenSession())
+ Assert.IsTrue(await (s.Query().Where(x => x.IsChiusa == yes).AnyAsync()));
+ }
+
+ [Test]
+ public async Task NotOnPropertyAsync()
+ {
+ using (var s = OpenSession())
+ Assert.IsTrue(await (s.Query().Where(x => !x.IsChiusa).AllAsync(x => !x.IsChiusa)));
+ }
+
+ [Test]
+ public async Task CompareWithInlineConstantAsync()
+ {
+ using (var s = OpenSession())
+ Assert.IsTrue(await (s.Query().Where(x => x.IsChiusa == false).AnyAsync()));
+ }
+
+ [Test]
+ public async Task CompareWithNotOnConstantAsync()
+ {
+ var no = false;
+ using (var s = OpenSession())
+ Assert.IsTrue(await (s.Query().Where(x => x.IsChiusa == !no).AnyAsync()));
+ }
+ }
+}
diff --git a/src/NHibernate.Test/NHSpecificTest/GH2704/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH2704/Entity.cs
new file mode 100644
index 00000000000..fb83ccf6766
--- /dev/null
+++ b/src/NHibernate.Test/NHSpecificTest/GH2704/Entity.cs
@@ -0,0 +1,22 @@
+using System;
+using NHibernate.Mapping.ByCode.Conformist;
+
+namespace NHibernate.Test.NHSpecificTest.GH2704
+{
+ public class Entity1
+ {
+ public virtual string Id { get; set; }
+ public virtual bool IsChiusa { get; set; }
+ }
+
+ class Entity1Map : ClassMapping
+ {
+ public Entity1Map()
+ {
+ Table("TA");
+
+ Id(x => x.Id);
+ Property(x => x.IsChiusa, m => m.Type());
+ }
+ }
+}
diff --git a/src/NHibernate.Test/NHSpecificTest/GH2704/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH2704/FixtureByCode.cs
new file mode 100644
index 00000000000..92bc0267d26
--- /dev/null
+++ b/src/NHibernate.Test/NHSpecificTest/GH2704/FixtureByCode.cs
@@ -0,0 +1,70 @@
+using System.Linq;
+using NHibernate.Cfg.MappingSchema;
+using NHibernate.Mapping.ByCode;
+using NUnit.Framework;
+
+namespace NHibernate.Test.NHSpecificTest.GH2704
+{
+ [TestFixture]
+ public class EnhancedUserTypeFixture : TestCaseMappingByCode
+ {
+ protected override HbmMapping GetMappings()
+ {
+ var mapper = new ModelMapper();
+
+ mapper.AddMapping();
+ return mapper.CompileMappingForAllExplicitlyAddedEntities();
+ }
+
+ protected override void OnSetUp()
+ {
+ using (var session = OpenSession())
+ using (var transaction = session.BeginTransaction())
+ {
+ session.Save(new Entity1() {Id = "id1", IsChiusa = true});
+ session.Save(new Entity1() {Id = "id2", IsChiusa = false});
+ transaction.Commit();
+ }
+ }
+
+ protected override void OnTearDown()
+ {
+ using (var session = OpenSession())
+ using (var transaction = session.BeginTransaction())
+ {
+ session.CreateQuery("delete from System.Object").ExecuteUpdate();
+ transaction.Commit();
+ }
+ }
+
+ [Test]
+ public void CompareWithConstant()
+ {
+ var yes = true;
+ using (var s = OpenSession())
+ Assert.IsTrue(s.Query().Where(x => x.IsChiusa == yes).Any());
+ }
+
+ [Test]
+ public void NotOnProperty()
+ {
+ using (var s = OpenSession())
+ Assert.IsTrue(s.Query().Where(x => !x.IsChiusa).All(x => !x.IsChiusa));
+ }
+
+ [Test]
+ public void CompareWithInlineConstant()
+ {
+ using (var s = OpenSession())
+ Assert.IsTrue(s.Query().Where(x => x.IsChiusa == false).Any());
+ }
+
+ [Test]
+ public void CompareWithNotOnConstant()
+ {
+ var no = false;
+ using (var s = OpenSession())
+ Assert.IsTrue(s.Query().Where(x => x.IsChiusa == !no).Any());
+ }
+ }
+}
diff --git a/src/NHibernate.Test/NHSpecificTest/GH2704/StringBoolToBoolUserType.cs b/src/NHibernate.Test/NHSpecificTest/GH2704/StringBoolToBoolUserType.cs
new file mode 100644
index 00000000000..e7101d8e099
--- /dev/null
+++ b/src/NHibernate.Test/NHSpecificTest/GH2704/StringBoolToBoolUserType.cs
@@ -0,0 +1,48 @@
+using System.Data;
+using System.Data.Common;
+using NHibernate.Engine;
+using NHibernate.SqlTypes;
+using NHibernate.UserTypes;
+
+namespace NHibernate.Test.NHSpecificTest.GH2704
+{
+ public class StringBoolToBoolUserType : IEnhancedUserType
+ {
+ public object Assemble(object cached, object owner) => cached;
+
+ public bool IsMutable => false;
+ public object DeepCopy(object value) => value;
+ public object Disassemble(object value) => value;
+ public object Replace(object original, object target, object owner) => original;
+
+ public object FromXMLString(string xml) => xml;
+ public string ToXMLString(object value) => ((bool) value) ? "'S'" : "'N'";
+ public string ObjectToSQLString(object value) => ((bool) value) ? "'S'" : "'N'";
+
+ bool IUserType.Equals(object x, object y) => x == null ? false : x.Equals(y);
+ public int GetHashCode(object x) => x == null ? typeof(bool).GetHashCode() + 473 : x.GetHashCode();
+
+ public object NullSafeGet(DbDataReader rs, string[] names, ISessionImplementor session, object owner)
+ {
+ var value = NHibernateUtil.String.NullSafeGet(rs, names[0], session);
+ if (value == null) return false;
+
+ return (string) value == "S";
+ }
+
+ public void NullSafeSet(DbCommand cmd, object value, int index, ISessionImplementor session)
+ {
+ if (value == null)
+ {
+ NHibernateUtil.String.NullSafeSet(cmd, null, index, session);
+ return;
+ }
+
+ value = (bool) value ? "S" : "N";
+ NHibernateUtil.String.NullSafeSet(cmd, value, index, session);
+ }
+
+ public System.Type ReturnedType => typeof(bool);
+ public SqlType[] SqlTypes => new SqlType[] {new SqlType(DbType.String)};
+ }
+}
diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/CaseNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/CaseNode.cs
index a60f08e7ad9..c48fd3aa248 100644
--- a/src/NHibernate/Hql/Ast/ANTLR/Tree/CaseNode.cs
+++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/CaseNode.cs
@@ -1,4 +1,6 @@
using System;
+using System.Collections.Generic;
+using System.Linq;
using Antlr.Runtime;
using NHibernate.Hql.Ast.ANTLR.Util;
using NHibernate.Type;
@@ -12,8 +14,10 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
/// Ported by: Steve Strong
///
[CLSCompliant(false)]
- public class CaseNode : AbstractSelectExpression, ISelectExpression
+ public class CaseNode : AbstractSelectExpression, ISelectExpression, IExpectedTypeAwareNode
{
+ private IType _expectedType;
+
public CaseNode(IToken token) : base(token)
{
}
@@ -22,46 +26,58 @@ public override IType DataType
{
get
{
- for (int i = 0; i < ChildCount; i++)
+ if (ExpectedType != null)
+ return ExpectedType;
+
+ foreach (var node in GetResultNodes())
{
- IASTNode whenOrElseClause = GetChild(i);
- if (whenOrElseClause.Type == HqlParser.WHEN)
- {
- // WHEN Child(0) THEN Child(1)
- IASTNode thenClause = whenOrElseClause.GetChild(1);
- if (thenClause is ISelectExpression)
- {
- if (!(thenClause is ParameterNode))
- {
- return (thenClause as ISelectExpression).DataType;
- }
- }
- }
- else if (whenOrElseClause.Type == HqlParser.ELSE)
- {
- // ELSE Child(0)
- IASTNode elseClause = whenOrElseClause.GetChild(0);
- if (elseClause is ISelectExpression)
- {
- if (!(elseClause is ParameterNode))
- {
- return (elseClause as ISelectExpression).DataType;
- }
- }
- }
- else
- {
- throw new HibernateException("Was expecting a WHEN or ELSE, but found a: " + whenOrElseClause.Text);
- }
+ if (node is ISelectExpression select && !(node is ParameterNode))
+ return select.DataType;
}
+
throw new HibernateException("Unable to determine data type of CASE statement.");
}
set { base.DataType = value; }
}
+ public IEnumerable GetResultNodes()
+ {
+ for (int i = 0; i < ChildCount; i++)
+ {
+ IASTNode whenOrElseClause = GetChild(i);
+ if (whenOrElseClause.Type == HqlParser.WHEN)
+ {
+ // WHEN Child(0) THEN Child(1)
+ yield return whenOrElseClause.GetChild(1);
+ }
+ else if (whenOrElseClause.Type == HqlParser.ELSE)
+ {
+ // ELSE Child(0)
+ yield return whenOrElseClause.GetChild(0);
+ }
+ else
+ {
+ throw new HibernateException("Was expecting a WHEN or ELSE, but found a: " + whenOrElseClause.Text);
+ }
+ }
+ }
+
public override void SetScalarColumnText(int i)
{
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i );
}
+
+ public IType ExpectedType
+ {
+ get => _expectedType;
+ set
+ {
+ _expectedType = value;
+ foreach (var node in GetResultNodes().OfType())
+ {
+ node.ExpectedType = ExpectedType;
+ }
+ }
+ }
}
}