Skip to content

Commit 5d374ab

Browse files
maca88fredericDelaporte
authored andcommitted
Fix parameter detection when using GroupBy method for Linq provider
1 parent f393a65 commit 5d374ab

File tree

7 files changed

+208
-1
lines changed

7 files changed

+208
-1
lines changed

src/NHibernate.Test/Async/Linq/ByMethod/GroupByHavingTests.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
using System;
1212
using System.Linq;
13+
using NHibernate.DomainModel.Northwind.Entities;
1314
using NUnit.Framework;
1415
using NHibernate.Linq;
1516

@@ -147,5 +148,45 @@ public async Task SingleKeyGroupAndCountWithHavingClauseAsync()
147148
var hornRow = orderCounts.Single(row => row.CompanyName == "Around the Horn");
148149
Assert.That(hornRow.OrderCount, Is.EqualTo(13));
149150
}
151+
152+
[Test]
153+
public async Task HavingWithStringEnumParameterAsync()
154+
{
155+
await (db.Users
156+
.GroupBy(p => p.Enum1)
157+
.Where(g => g.Key == EnumStoredAsString.Large)
158+
.Select(g => g.Count())
159+
.ToListAsync());
160+
await (db.Users
161+
.GroupBy(p => new StringEnumGroup {Enum = p.Enum1})
162+
.Where(g => g.Key.Enum == EnumStoredAsString.Large)
163+
.Select(g => g.Count())
164+
.ToListAsync());
165+
await (db.Users
166+
.GroupBy(p => new[] {p.Enum1})
167+
.Where(g => g.Key[0] == EnumStoredAsString.Large)
168+
.Select(g => g.Count())
169+
.ToListAsync());
170+
await (db.Users
171+
.GroupBy(p => new {p.Enum1})
172+
.Where(g => g.Key.Enum1 == EnumStoredAsString.Large)
173+
.Select(g => g.Count())
174+
.ToListAsync());
175+
await (db.Users
176+
.GroupBy(p => new {Test = new {Test2 = p.Enum1}})
177+
.Where(g => g.Key.Test.Test2 == EnumStoredAsString.Large)
178+
.Select(g => g.Count())
179+
.ToListAsync());
180+
await (db.Users
181+
.GroupBy(p => new {Test = new[] {p.Enum1}})
182+
.Where(g => g.Key.Test[0] == EnumStoredAsString.Large)
183+
.Select(g => g.Count())
184+
.ToListAsync());
185+
}
186+
187+
private class StringEnumGroup
188+
{
189+
public EnumStoredAsString Enum { get; set; }
190+
}
150191
}
151192
}

src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,40 @@ public async Task GroupByAndAnyAsync()
382382
Assert.That(namesAreNotEmpty, Is.True);
383383
}
384384

385+
[Test]
386+
public async Task GroupByWithStringEnumParameterAsync()
387+
{
388+
await (db.Users
389+
.GroupBy(p => p.Enum1)
390+
.Select(g => g.Key == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
391+
.ToListAsync());
392+
await (db.Users
393+
.GroupBy(p => new StringEnumGroup {Enum = p.Enum1})
394+
.Select(g => g.Key.Enum == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
395+
.ToListAsync());
396+
await (db.Users
397+
.GroupBy(p => new[] {p.Enum1})
398+
.Select(g => g.Key[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
399+
.ToListAsync());
400+
await (db.Users
401+
.GroupBy(p => new {p.Enum1})
402+
.Select(g => g.Key.Enum1 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
403+
.ToListAsync());
404+
await (db.Users
405+
.GroupBy(p => new {Test = new {Test2 = p.Enum1}})
406+
.Select(g => g.Key.Test.Test2 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
407+
.ToListAsync());
408+
await (db.Users
409+
.GroupBy(p => new {Test = new[] {p.Enum1}})
410+
.Select(g => g.Key.Test[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
411+
.ToListAsync());
412+
}
413+
414+
private class StringEnumGroup
415+
{
416+
public EnumStoredAsString Enum { get; set; }
417+
}
418+
385419
[Test]
386420
public async Task SelectFirstElementFromProductsGroupedByUnitPriceAsync()
387421
{

src/NHibernate.Test/Linq/ByMethod/GroupByHavingTests.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Linq;
3+
using NHibernate.DomainModel.Northwind.Entities;
34
using NUnit.Framework;
45

56
namespace NHibernate.Test.Linq.ByMethod
@@ -135,5 +136,45 @@ public void SingleKeyGroupAndCountWithHavingClause()
135136
var hornRow = orderCounts.Single(row => row.CompanyName == "Around the Horn");
136137
Assert.That(hornRow.OrderCount, Is.EqualTo(13));
137138
}
139+
140+
[Test]
141+
public void HavingWithStringEnumParameter()
142+
{
143+
db.Users
144+
.GroupBy(p => p.Enum1)
145+
.Where(g => g.Key == EnumStoredAsString.Large)
146+
.Select(g => g.Count())
147+
.ToList();
148+
db.Users
149+
.GroupBy(p => new StringEnumGroup {Enum = p.Enum1})
150+
.Where(g => g.Key.Enum == EnumStoredAsString.Large)
151+
.Select(g => g.Count())
152+
.ToList();
153+
db.Users
154+
.GroupBy(p => new[] {p.Enum1})
155+
.Where(g => g.Key[0] == EnumStoredAsString.Large)
156+
.Select(g => g.Count())
157+
.ToList();
158+
db.Users
159+
.GroupBy(p => new {p.Enum1})
160+
.Where(g => g.Key.Enum1 == EnumStoredAsString.Large)
161+
.Select(g => g.Count())
162+
.ToList();
163+
db.Users
164+
.GroupBy(p => new {Test = new {Test2 = p.Enum1}})
165+
.Where(g => g.Key.Test.Test2 == EnumStoredAsString.Large)
166+
.Select(g => g.Count())
167+
.ToList();
168+
db.Users
169+
.GroupBy(p => new {Test = new[] {p.Enum1}})
170+
.Where(g => g.Key.Test[0] == EnumStoredAsString.Large)
171+
.Select(g => g.Count())
172+
.ToList();
173+
}
174+
175+
private class StringEnumGroup
176+
{
177+
public EnumStoredAsString Enum { get; set; }
178+
}
138179
}
139180
}

src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,40 @@ public void GroupByAndAny()
371371
Assert.That(namesAreNotEmpty, Is.True);
372372
}
373373

374+
[Test]
375+
public void GroupByWithStringEnumParameter()
376+
{
377+
db.Users
378+
.GroupBy(p => p.Enum1)
379+
.Select(g => g.Key == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
380+
.ToList();
381+
db.Users
382+
.GroupBy(p => new StringEnumGroup {Enum = p.Enum1})
383+
.Select(g => g.Key.Enum == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
384+
.ToList();
385+
db.Users
386+
.GroupBy(p => new[] {p.Enum1})
387+
.Select(g => g.Key[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
388+
.ToList();
389+
db.Users
390+
.GroupBy(p => new {p.Enum1})
391+
.Select(g => g.Key.Enum1 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
392+
.ToList();
393+
db.Users
394+
.GroupBy(p => new {Test = new {Test2 = p.Enum1}})
395+
.Select(g => g.Key.Test.Test2 == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
396+
.ToList();
397+
db.Users
398+
.GroupBy(p => new {Test = new[] {p.Enum1}})
399+
.Select(g => g.Key.Test[0] == EnumStoredAsString.Large ? g.Sum(o => o.Id) : 0)
400+
.ToList();
401+
}
402+
403+
private class StringEnumGroup
404+
{
405+
public EnumStoredAsString Enum { get; set; }
406+
}
407+
374408
[Test]
375409
public void SelectFirstElementFromProductsGroupedByUnitPrice()
376410
{

src/NHibernate/Linq/ExpressionExtensions.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,23 @@ public static bool IsGroupingKey(this MemberExpression expression)
1515
expression.Member.DeclaringType.IsGenericType && expression.Member.DeclaringType.GetGenericTypeDefinition() == typeof(IGrouping<,>);
1616
}
1717

18+
internal static bool TryGetGroupResultOperator(this MemberExpression keyExpression, out GroupResultOperator groupBy)
19+
{
20+
if (keyExpression.IsGroupingKey() &&
21+
keyExpression.Expression is QuerySourceReferenceExpression querySource &&
22+
querySource.ReferencedQuerySource is MainFromClause fromClause &&
23+
fromClause.FromExpression is SubQueryExpression query)
24+
{
25+
groupBy = query.QueryModel.ResultOperators
26+
.OfType<GroupResultOperator>()
27+
.FirstOrDefault(o => o.KeySelector.Type == keyExpression.Type);
28+
return groupBy != null;
29+
}
30+
31+
groupBy = null;
32+
return false;
33+
}
34+
1835
public static bool IsGroupingKeyOf(this MemberExpression expression,GroupResultOperator groupBy)
1936
{
2037
if (!expression.IsGroupingKey())

src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ private void VisitAssign(Expression leftNode, Expression rightNode)
343343
private void AddRelatedExpression(Expression node, Expression left, Expression right)
344344
{
345345
if (left.NodeType == ExpressionType.MemberAccess ||
346+
left.NodeType == ExpressionType.ArrayIndex || // e.g. group.Key[0] == variable
346347
IsDynamicMember(left) ||
347348
left is QuerySourceReferenceExpression)
348349
{

src/NHibernate/Util/ExpressionsHelper.cs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
using Remotion.Linq.Clauses;
1818
using Remotion.Linq.Clauses.Expressions;
1919
using Remotion.Linq.Clauses.ResultOperators;
20-
using Remotion.Linq.Parsing;
20+
using TransparentIdentifierRemovingExpressionVisitor = NHibernate.Linq.Visitors.TransparentIdentifierRemovingExpressionVisitor;
2121

2222
namespace NHibernate.Util
2323
{
@@ -594,6 +594,43 @@ private static IType GetType(
594594
: TypeFactory.GetDefaultTypeFor(member.ConvertType); // (long)q.OneToMany[0]
595595
}
596596

597+
private class GroupingKeyFlattener : NhExpressionVisitor
598+
{
599+
private bool _flattened;
600+
601+
public static Expression FlattenGroupingKey(Expression expression)
602+
{
603+
var visitor = new GroupingKeyFlattener();
604+
expression = visitor.Visit(expression);
605+
if (visitor._flattened)
606+
{
607+
expression = TransparentIdentifierRemovingExpressionVisitor.ReplaceTransparentIdentifiers(expression);
608+
// When the grouping key is an array we have to unwrap it (e.g. group.Key[0] == variable)
609+
if (expression.NodeType == ExpressionType.ArrayIndex &&
610+
expression is BinaryExpression binaryExpression &&
611+
binaryExpression.Left is NewArrayExpression newArray &&
612+
binaryExpression.Right is ConstantExpression indexExpression &&
613+
indexExpression.Value is int index)
614+
{
615+
return newArray.Expressions[index];
616+
}
617+
}
618+
619+
return expression;
620+
}
621+
622+
protected override Expression VisitMember(MemberExpression node)
623+
{
624+
if (node.TryGetGroupResultOperator(out var groupBy))
625+
{
626+
_flattened = true;
627+
return groupBy.KeySelector;
628+
}
629+
630+
return base.VisitMember(node);
631+
}
632+
}
633+
597634
private class MemberMetadataExtractor : NhExpressionVisitor
598635
{
599636
private readonly List<MemberMetadataResult> _childrenResults = new List<MemberMetadataResult>();
@@ -639,6 +676,8 @@ private static bool TryGetAllMemberMetadata(
639676
bool hasIndexer,
640677
out MemberMetadataResult results)
641678
{
679+
expression = GroupingKeyFlattener.FlattenGroupingKey(expression);
680+
642681
var extractor = new MemberMetadataExtractor(memberPaths, convertType, hasIndexer);
643682
extractor.Accept(expression);
644683
results = extractor._entityName != null || extractor._childrenResults.Count > 0

0 commit comments

Comments
 (0)