Skip to content

Commit 83db507

Browse files
authored
LINQ subqueries wrongly altered by SelectClauseVisitor (#3271)
1 parent 3088367 commit 83db507

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

src/NHibernate.Test/Async/Linq/WhereTests.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,17 @@ where sheet.Users.Contains(user)
644644
Assert.That(query.Count, Is.EqualTo(2));
645645
}
646646

647+
[Test]
648+
public async Task TimesheetsWithEnumerableContainsOnSelectAsync()
649+
{
650+
var value = (EnumStoredAsInt32) 1000;
651+
var query = await ((from sheet in db.Timesheets
652+
where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
653+
select sheet).ToListAsync());
654+
655+
Assert.That(query.Count, Is.EqualTo(1));
656+
}
657+
647658
[Test]
648659
public async Task SearchOnObjectTypeWithExtensionMethodAsync()
649660
{

src/NHibernate.Test/Linq/WhereTests.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,17 @@ where sheet.Users.Contains(user)
645645
Assert.That(query.Count, Is.EqualTo(2));
646646
}
647647

648+
[Test]
649+
public void TimesheetsWithEnumerableContainsOnSelect()
650+
{
651+
var value = (EnumStoredAsInt32) 1000;
652+
var query = (from sheet in db.Timesheets
653+
where sheet.Users.Select(x => x.NullableEnum2 ?? value).Contains(value)
654+
select sheet).ToList();
655+
656+
Assert.That(query.Count, Is.EqualTo(1));
657+
}
658+
648659
[Test]
649660
public void SearchOnObjectTypeWithExtensionMethod()
650661
{

src/NHibernate/Linq/Visitors/QueryModelVisitor.cs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ public static ExpressionToHqlTranslationResults GenerateHqlQuery(QueryModel quer
114114
private readonly NhLinqExpressionReturnType? _rootReturnType;
115115
private static readonly ResultOperatorMap ResultOperatorMap;
116116
private bool _serverSide = true;
117+
private readonly bool _root;
117118

118119
public VisitorParameters VisitorParameters { get; }
119120

@@ -161,6 +162,7 @@ private QueryModelVisitor(VisitorParameters visitorParameters, bool root, QueryM
161162
_queryMode = root ? visitorParameters.RootQueryMode : QueryMode.Select;
162163
VisitorParameters = visitorParameters;
163164
Model = queryModel;
165+
_root = root;
164166
_rootReturnType = root ? rootReturnType : null;
165167
_hqlTree = new IntermediateHqlTree(root, _queryMode);
166168
}
@@ -467,19 +469,27 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que
467469
}
468470

469471
//This is a standard select query
472+
_hqlTree.AddSelectClause(GetSelectClause(selectClause.Selector));
473+
474+
base.VisitSelectClause(selectClause, queryModel);
475+
}
476+
477+
private HqlSelect GetSelectClause(Expression selectClause)
478+
{
479+
if (!_root)
480+
return _hqlTree.TreeBuilder.Select(
481+
HqlGeneratorExpressionVisitor.Visit(selectClause, VisitorParameters).AsExpression());
470482

471483
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters);
472484

473-
visitor.VisitSelector(selectClause.Selector);
485+
visitor.VisitSelector(selectClause);
474486

475487
if (visitor.ProjectionExpression != null)
476488
{
477489
_hqlTree.AddItemTransformer(visitor.ProjectionExpression);
478490
}
479491

480-
_hqlTree.AddSelectClause(_hqlTree.TreeBuilder.Select(visitor.GetHqlNodes()));
481-
482-
base.VisitSelectClause(selectClause, queryModel);
492+
return _hqlTree.TreeBuilder.Select(visitor.GetHqlNodes());
483493
}
484494

485495
private void VisitInsertClause(Expression expression)
@@ -527,6 +537,9 @@ private void VisitUpdateClause(Expression expression)
527537

528538
private void VisitDeleteClause(Expression expression)
529539
{
540+
if (!_root)
541+
return;
542+
530543
// We only need to check there is no unexpected select, for avoiding silently ignoring them.
531544
var visitor = new SelectClauseVisitor(typeof(object[]), VisitorParameters);
532545
visitor.VisitSelector(expression);

0 commit comments

Comments
 (0)