Skip to content

Commit ccc90fe

Browse files
committed
Propagate ExpectedType to CaseNode in hql
1 parent 5d374ab commit ccc90fe

File tree

5 files changed

+274
-31
lines changed

5 files changed

+274
-31
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//------------------------------------------------------------------------------
2+
// <auto-generated>
3+
// This code was generated by AsyncGenerator.
4+
//
5+
// Changes to this file may cause incorrect behavior and will be lost if
6+
// the code is regenerated.
7+
// </auto-generated>
8+
//------------------------------------------------------------------------------
9+
10+
11+
using System.Linq;
12+
using NHibernate.Cfg.MappingSchema;
13+
using NHibernate.Mapping.ByCode;
14+
using NUnit.Framework;
15+
using NHibernate.Linq;
16+
17+
namespace NHibernate.Test.NHSpecificTest.GH2704
18+
{
19+
using System.Threading.Tasks;
20+
[TestFixture]
21+
public class EnhancedUserTypeFixtureAsync : TestCaseMappingByCode
22+
{
23+
protected override HbmMapping GetMappings()
24+
{
25+
var mapper = new ModelMapper();
26+
27+
mapper.AddMapping<Entity1Map>();
28+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
29+
}
30+
31+
protected override void OnSetUp()
32+
{
33+
using (var session = OpenSession())
34+
using (var transaction = session.BeginTransaction())
35+
{
36+
session.CreateSQLQuery(
37+
@"insert into TA (id,ischiusa) values ('id1','S');
38+
insert into TA (id,ischiusa) values ('id2','N');").ExecuteUpdate();
39+
transaction.Commit();
40+
}
41+
}
42+
43+
protected override void OnTearDown()
44+
{
45+
using (var session = OpenSession())
46+
using (var transaction = session.BeginTransaction())
47+
{
48+
session.CreateQuery("delete from System.Object").ExecuteUpdate();
49+
transaction.Commit();
50+
}
51+
}
52+
53+
[Test]
54+
public async Task CompareWithConstantAsync()
55+
{
56+
var yes = true;
57+
using (var s = OpenSession())
58+
Assert.IsTrue(await (s.Query<Entity1>().Where(x => x.IsChiusa == yes).AnyAsync()));
59+
}
60+
61+
[Test]
62+
public async Task NotOnPropertyAsync()
63+
{
64+
using (var s = OpenSession())
65+
Assert.IsTrue(await (s.Query<Entity1>().Where(x => !x.IsChiusa).AllAsync(x => !x.IsChiusa)));
66+
}
67+
68+
[Test]
69+
public async Task CompareWithInlineConstantAsync()
70+
{
71+
using (var s = OpenSession())
72+
Assert.IsTrue(await (s.Query<Entity1>().Where(x => x.IsChiusa == false).AnyAsync()));
73+
}
74+
75+
[Test]
76+
public async Task CompareWithNotOnConstantAsync()
77+
{
78+
var no = false;
79+
using (var s = OpenSession())
80+
Assert.IsTrue(await (s.Query<Entity1>().Where(x => x.IsChiusa == !no).AnyAsync()));
81+
}
82+
}
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using NHibernate.Mapping.ByCode.Conformist;
3+
4+
namespace NHibernate.Test.NHSpecificTest.GH2704
5+
{
6+
public class Entity1
7+
{
8+
public virtual string Id { get; set; }
9+
public virtual bool IsChiusa { get; set; }
10+
}
11+
12+
class Entity1Map : ClassMapping<Entity1>
13+
{
14+
public Entity1Map()
15+
{
16+
Table("TA");
17+
18+
Id(x => x.Id);
19+
Property(x => x.IsChiusa, m => m.Type<StringBoolToBoolUserType>());
20+
}
21+
}
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System.Linq;
2+
using NHibernate.Cfg.MappingSchema;
3+
using NHibernate.Mapping.ByCode;
4+
using NUnit.Framework;
5+
6+
namespace NHibernate.Test.NHSpecificTest.GH2704
7+
{
8+
[TestFixture]
9+
public class EnhancedUserTypeFixture : TestCaseMappingByCode
10+
{
11+
protected override HbmMapping GetMappings()
12+
{
13+
var mapper = new ModelMapper();
14+
15+
mapper.AddMapping<Entity1Map>();
16+
return mapper.CompileMappingForAllExplicitlyAddedEntities();
17+
}
18+
19+
protected override void OnSetUp()
20+
{
21+
using (var session = OpenSession())
22+
using (var transaction = session.BeginTransaction())
23+
{
24+
session.CreateSQLQuery(
25+
@"insert into TA (id,ischiusa) values ('id1','S');
26+
insert into TA (id,ischiusa) values ('id2','N');").ExecuteUpdate();
27+
transaction.Commit();
28+
}
29+
}
30+
31+
protected override void OnTearDown()
32+
{
33+
using (var session = OpenSession())
34+
using (var transaction = session.BeginTransaction())
35+
{
36+
session.CreateQuery("delete from System.Object").ExecuteUpdate();
37+
transaction.Commit();
38+
}
39+
}
40+
41+
[Test]
42+
public void CompareWithConstant()
43+
{
44+
var yes = true;
45+
using (var s = OpenSession())
46+
Assert.IsTrue(s.Query<Entity1>().Where(x => x.IsChiusa == yes).Any());
47+
}
48+
49+
[Test]
50+
public void NotOnProperty()
51+
{
52+
using (var s = OpenSession())
53+
Assert.IsTrue(s.Query<Entity1>().Where(x => !x.IsChiusa).All(x => !x.IsChiusa));
54+
}
55+
56+
[Test]
57+
public void CompareWithInlineConstant()
58+
{
59+
using (var s = OpenSession())
60+
Assert.IsTrue(s.Query<Entity1>().Where(x => x.IsChiusa == false).Any());
61+
}
62+
63+
[Test]
64+
public void CompareWithNotOnConstant()
65+
{
66+
var no = false;
67+
using (var s = OpenSession())
68+
Assert.IsTrue(s.Query<Entity1>().Where(x => x.IsChiusa == !no).Any());
69+
}
70+
}
71+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
using System.Data;
2+
using System.Data.Common;
3+
using NHibernate.Engine;
4+
using NHibernate.SqlTypes;
5+
using NHibernate.UserTypes;
6+
7+
namespace NHibernate.Test.NHSpecificTest.GH2704
8+
{
9+
public class StringBoolToBoolUserType
10+
: IEnhancedUserType
11+
{
12+
public object Assemble(object cached, object owner) => cached;
13+
14+
public bool IsMutable => false;
15+
public object DeepCopy(object value) => value;
16+
public object Disassemble(object value) => value;
17+
public object Replace(object original, object target, object owner) => original;
18+
19+
public object FromXMLString(string xml) => xml;
20+
public string ToXMLString(object value) => ((bool) value) ? "'S'" : "'N'";
21+
public string ObjectToSQLString(object value) => ((bool) value) ? "'S'" : "'N'";
22+
23+
bool IUserType.Equals(object x, object y) => x == null ? false : x.Equals(y);
24+
public int GetHashCode(object x) => x == null ? typeof(bool).GetHashCode() + 473 : x.GetHashCode();
25+
26+
public object NullSafeGet(DbDataReader rs, string[] names, ISessionImplementor session, object owner)
27+
{
28+
29+
var value = NHibernateUtil.String.NullSafeGet(rs, names[0], session);
30+
if (value == null) return false;
31+
32+
return (string) value == "S";
33+
}
34+
35+
public void NullSafeSet(DbCommand cmd, object value, int index, ISessionImplementor session)
36+
{
37+
38+
if (value == null)
39+
{
40+
NHibernateUtil.String.NullSafeSet(cmd, null, index, session);
41+
return;
42+
}
43+
44+
value = (bool) value ? "S" : "N";
45+
NHibernateUtil.String.NullSafeSet(cmd, value, index, session);
46+
}
47+
48+
public System.Type ReturnedType => typeof(bool);
49+
public SqlType[] SqlTypes => new SqlType[] {new SqlType(DbType.String)};
50+
}
51+
}
+47-31
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
24
using Antlr.Runtime;
35
using NHibernate.Hql.Ast.ANTLR.Util;
46
using NHibernate.Type;
@@ -12,8 +14,10 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree
1214
/// Ported by: Steve Strong
1315
/// </summary>
1416
[CLSCompliant(false)]
15-
public class CaseNode : AbstractSelectExpression, ISelectExpression
17+
public class CaseNode : AbstractSelectExpression, ISelectExpression, IExpectedTypeAwareNode
1618
{
19+
private IType _expectedType;
20+
1721
public CaseNode(IToken token) : base(token)
1822
{
1923
}
@@ -22,46 +26,58 @@ public override IType DataType
2226
{
2327
get
2428
{
25-
for (int i = 0; i < ChildCount; i++)
29+
if (ExpectedType != null)
30+
return ExpectedType;
31+
32+
foreach (var node in GetResultNodes())
2633
{
27-
IASTNode whenOrElseClause = GetChild(i);
28-
if (whenOrElseClause.Type == HqlParser.WHEN)
29-
{
30-
// WHEN Child(0) THEN Child(1)
31-
IASTNode thenClause = whenOrElseClause.GetChild(1);
32-
if (thenClause is ISelectExpression)
33-
{
34-
if (!(thenClause is ParameterNode))
35-
{
36-
return (thenClause as ISelectExpression).DataType;
37-
}
38-
}
39-
}
40-
else if (whenOrElseClause.Type == HqlParser.ELSE)
41-
{
42-
// ELSE Child(0)
43-
IASTNode elseClause = whenOrElseClause.GetChild(0);
44-
if (elseClause is ISelectExpression)
45-
{
46-
if (!(elseClause is ParameterNode))
47-
{
48-
return (elseClause as ISelectExpression).DataType;
49-
}
50-
}
51-
}
52-
else
53-
{
54-
throw new HibernateException("Was expecting a WHEN or ELSE, but found a: " + whenOrElseClause.Text);
55-
}
34+
if (node is ISelectExpression select && !(node is ParameterNode))
35+
return select.DataType;
5636
}
37+
5738
throw new HibernateException("Unable to determine data type of CASE statement.");
5839
}
5940
set { base.DataType = value; }
6041
}
6142

43+
public IEnumerable<IASTNode> GetResultNodes()
44+
{
45+
for (int i = 0; i < ChildCount; i++)
46+
{
47+
IASTNode whenOrElseClause = GetChild(i);
48+
if (whenOrElseClause.Type == HqlParser.WHEN)
49+
{
50+
// WHEN Child(0) THEN Child(1)
51+
yield return whenOrElseClause.GetChild(1);
52+
}
53+
else if (whenOrElseClause.Type == HqlParser.ELSE)
54+
{
55+
// ELSE Child(0)
56+
yield return whenOrElseClause.GetChild(0);
57+
}
58+
else
59+
{
60+
throw new HibernateException("Was expecting a WHEN or ELSE, but found a: " + whenOrElseClause.Text);
61+
}
62+
}
63+
}
64+
6265
public override void SetScalarColumnText(int i)
6366
{
6467
ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i );
6568
}
69+
70+
public IType ExpectedType
71+
{
72+
get => _expectedType;
73+
set
74+
{
75+
_expectedType = value;
76+
foreach (var node in GetResultNodes().OfType<IExpectedTypeAwareNode>())
77+
{
78+
node.ExpectedType = ExpectedType;
79+
}
80+
}
81+
}
6682
}
6783
}

0 commit comments

Comments
 (0)