Skip to content

Fix parameter detection for Contains method for Linq provider #2520

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public class AnotherEntityRequired

public virtual ISet<AnotherEntity> RelatedItems { get; set; } = new HashSet<AnotherEntity>();

public virtual ISet<AnotherEntityRequired> RequiredRelatedItems { get; set; } = new HashSet<AnotherEntityRequired>();

public virtual bool? NullableBool { get; set; }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,9 @@
<key column="Id"/>
<one-to-many class="AnotherEntity"/>
</set>
<set name="RequiredRelatedItems" lazy="true" inverse="true">
<key column="Id"/>
<one-to-many class="AnotherEntityRequired"/>
</set>
</class>
</hibernate-mapping>
74 changes: 74 additions & 0 deletions src/NHibernate.Test/Async/Linq/ParameterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,80 @@ public async Task UsingEntityParameterTwiceAsync()
1));
}

[Test]
public async Task UsingEntityParameterForCollectionAsync()
{
var item = await (db.OrderLines.FirstAsync());
await (AssertTotalParametersAsync(
db.Orders.Where(o => o.OrderLines.Contains(item)),
1));
}

[Test]
public async Task UsingProxyParameterForCollectionAsync()
{
var item = await (session.LoadAsync<Order>(10248));
Assert.That(NHibernateUtil.IsInitialized(item), Is.False);
await (AssertTotalParametersAsync(
db.Customers.Where(o => o.Orders.Contains(item)),
1));
}

[Test]
public async Task UsingFieldProxyParameterForCollectionAsync()
{
var item = await (session.Query<AnotherEntityRequired>().FirstAsync());
await (AssertTotalParametersAsync(
session.Query<AnotherEntityRequired>().Where(o => o.RequiredRelatedItems.Contains(item)),
1));
}

[Test]
public async Task UsingEntityParameterInSubQueryAsync()
{
var item = await (db.Customers.FirstAsync());
var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item);
await (AssertTotalParametersAsync(
db.Orders.Where(o => subQuery.Contains(o.Customer)),
1));
}

[Test]
public async Task UsingEntityParameterForCollectionSelectionAsync()
{
var item = await (db.OrderLines.FirstAsync());
await (AssertTotalParametersAsync(
db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item),
1));
}

[Test]
public async Task UsingFieldProxyParameterForCollectionSelectionAsync()
{
var item = await (session.Query<AnotherEntityRequired>().FirstAsync());
await (AssertTotalParametersAsync(
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item),
1));
}

[Test]
public async Task UsingEntityListParameterForCollectionSelectionAsync()
{
var items = new[] {await (db.OrderLines.FirstAsync())};
await (AssertTotalParametersAsync(
db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)),
1));
}

[Test]
public async Task UsingFieldProxyListParameterForCollectionSelectionAsync()
{
var items = new[] {await (session.Query<AnotherEntityRequired>().FirstAsync())};
await (AssertTotalParametersAsync(
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)),
1));
}

[Test]
public async Task UsingTwoEntityParametersAsync()
{
Expand Down
74 changes: 74 additions & 0 deletions src/NHibernate.Test/Linq/ParameterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,80 @@ public void UsingEntityParameterTwice()
1);
}

[Test]
public void UsingEntityParameterForCollection()
{
var item = db.OrderLines.First();
AssertTotalParameters(
db.Orders.Where(o => o.OrderLines.Contains(item)),
1);
}

[Test]
public void UsingProxyParameterForCollection()
{
var item = session.Load<Order>(10248);
Assert.That(NHibernateUtil.IsInitialized(item), Is.False);
AssertTotalParameters(
db.Customers.Where(o => o.Orders.Contains(item)),
1);
}

[Test]
public void UsingFieldProxyParameterForCollection()
{
var item = session.Query<AnotherEntityRequired>().First();
AssertTotalParameters(
session.Query<AnotherEntityRequired>().Where(o => o.RequiredRelatedItems.Contains(item)),
1);
}

[Test]
public void UsingEntityParameterInSubQuery()
{
var item = db.Customers.First();
var subQuery = db.Orders.Select(o => o.Customer).Where(o => o == item);
AssertTotalParameters(
db.Orders.Where(o => subQuery.Contains(o.Customer)),
1);
}

[Test]
public void UsingEntityParameterForCollectionSelection()
{
var item = db.OrderLines.First();
AssertTotalParameters(
db.Orders.SelectMany(o => o.OrderLines).Where(o => o == item),
1);
}

[Test]
public void UsingFieldProxyParameterForCollectionSelection()
{
var item = session.Query<AnotherEntityRequired>().First();
AssertTotalParameters(
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => o == item),
1);
}

[Test]
public void UsingEntityListParameterForCollectionSelection()
{
var items = new[] {db.OrderLines.First()};
AssertTotalParameters(
db.Orders.SelectMany(o => o.OrderLines).Where(o => items.Contains(o)),
1);
}

[Test]
public void UsingFieldProxyListParameterForCollectionSelection()
{
var items = new[] {session.Query<AnotherEntityRequired>().First()};
AssertTotalParameters(
session.Query<AnotherEntityRequired>().SelectMany(o => o.RequiredRelatedItems).Where(o => items.Contains(o)),
1);
}

[Test]
public void UsingTwoEntityParameters()
{
Expand Down
61 changes: 31 additions & 30 deletions src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ private static IType GetCandidateType(
if (!ExpressionsHelper.TryGetMappedType(sessionFactory, relatedExpression, out var mappedType, out _, out _, out _))
continue;

if (mappedType.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(relatedExpression))
if (mappedType.IsCollectionType)
{
var collection = (IQueryableCollection) ((IAssociationType) mappedType).GetAssociatedJoinable(sessionFactory);
mappedType = collection.ElementType;
Expand Down Expand Up @@ -176,7 +176,6 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor
new Dictionary<NamedParameter, HashSet<ConstantExpression>>();
public readonly Dictionary<Expression, HashSet<Expression>> RelatedExpressions =
new Dictionary<Expression, HashSet<Expression>>();
public readonly HashSet<Expression> SequenceSelectorExpressions = new HashSet<Expression>();

public ConstantTypeLocatorVisitor(
bool removeMappedAsCalls,
Expand Down Expand Up @@ -282,41 +281,43 @@ protected override Expression VisitConstant(ConstantExpression node)
}

protected override Expression VisitSubQuery(SubQueryExpression node)
{
if (!TryLinkContainsMethod(node.QueryModel))
{
node.QueryModel.TransformExpressions(Visit);
}

return node;
}

private bool TryLinkContainsMethod(QueryModel queryModel)
{
// ReLinq wraps all ResultOperatorExpressionNodeBase into a SubQueryExpression. In case of
// ContainsResultOperator where the constant expression is dislocated from the related expression,
// we have to manually link the related expressions.
if (node.QueryModel.ResultOperators.Count == 1 &&
node.QueryModel.ResultOperators[0] is ContainsResultOperator containsOperator &&
node.QueryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference &&
querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause &&
mainFromClause.FromExpression is ConstantExpression constantExpression)
if (queryModel.ResultOperators.Count != 1 ||
!(queryModel.ResultOperators[0] is ContainsResultOperator containsOperator) ||
!(queryModel.SelectClause.Selector is QuerySourceReferenceExpression querySourceReference) ||
!(querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause))
{
VisitConstant(constantExpression);
AddRelatedExpression(constantExpression, UnwrapUnary(Visit(containsOperator.Item)));
// Copy all found MemberExpressions to the constant expression
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
if (RelatedExpressions.TryGetValue(containsOperator.Item, out var set))
{
foreach (var nestedMemberExpression in set)
{
AddRelatedExpression(constantExpression, nestedMemberExpression);
}
}
return false;
}
else
{
// In case a parameter is related to a sequence selector we will have to get the underlying item type
// (e.g. q.Where(o => o.Users.Any(u => u == user)))
if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase))
{
SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector);
}

node.QueryModel.TransformExpressions(Visit);
var left = UnwrapUnary(Visit(mainFromClause.FromExpression));
var right = UnwrapUnary(Visit(containsOperator.Item));
// The constant is on the left side (e.g. db.Users.Where(o => users.Contains(o)))
// The constant is on the right side (e.g. db.Customers.Where(o => o.Orders.Contains(item)))
if (left.NodeType != ExpressionType.Constant && right.NodeType != ExpressionType.Constant)
{
return false;
}

return node;
// Copy all found MemberExpressions to the constant expression
// (e.g. values.Contains(o.Name != o.Name2 ? o.Enum1 : o.Enum2) -> copy o.Enum1 and o.Enum2)
AddRelatedExpression(null, left, right);
AddRelatedExpression(null, right, left);

return true;
}

private void VisitAssign(Expression leftNode, Expression rightNode)
Expand Down Expand Up @@ -346,7 +347,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
left is QuerySourceReferenceExpression)
{
AddRelatedExpression(right, left);
if (NonVoidOperators.Contains(node.NodeType))
if (node != null && NonVoidOperators.Contains(node.NodeType))
{
AddRelatedExpression(node, left);
}
Expand All @@ -359,7 +360,7 @@ private void AddRelatedExpression(Expression node, Expression left, Expression r
foreach (var nestedMemberExpression in set)
{
AddRelatedExpression(right, nestedMemberExpression);
if (NonVoidOperators.Contains(node.NodeType))
if (node != null && NonVoidOperators.Contains(node.NodeType))
{
AddRelatedExpression(node, nestedMemberExpression);
}
Expand Down