From 21ab72e86a0a8ba410d9fecc7879af71a854b96f Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 10 Mar 2020 21:08:34 +0100 Subject: [PATCH 01/10] Add cross join support for Hql and Linq query provider --- .../Async/Hql/EntityJoinHqlTest.cs | 21 +++++++++++++++ .../Async/Linq/ByMethod/JoinTests.cs | 26 +++++++++++++++++++ src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 23 +++++++++++++++- .../Linq/ByMethod/JoinTests.cs | 26 +++++++++++++++++++ src/NHibernate/AdoNet/Util/BasicFormatter.cs | 1 + src/NHibernate/Dialect/DB2Dialect.cs | 3 +++ src/NHibernate/Dialect/Dialect.cs | 5 ++++ src/NHibernate/Dialect/InformixDialect0940.cs | 5 +++- src/NHibernate/Dialect/Oracle10gDialect.cs | 3 +++ src/NHibernate/Dialect/Oracle8iDialect.cs | 3 +++ src/NHibernate/Dialect/SybaseASA9Dialect.cs | 3 +++ src/NHibernate/Dialect/SybaseASE15Dialect.cs | 5 +++- src/NHibernate/Hql/Ast/ANTLR/Hql.g | 2 ++ src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g | 3 +++ .../Hql/Ast/ANTLR/Util/JoinProcessor.cs | 2 ++ src/NHibernate/Hql/Ast/HqlTreeBuilder.cs | 5 ++++ src/NHibernate/Hql/Ast/HqlTreeNode.cs | 15 +++++++++++ .../Linq/Visitors/QueryModelVisitor.cs | 18 ++++++------- src/NHibernate/SqlCommand/ANSIJoinFragment.cs | 11 +++++++- src/NHibernate/SqlCommand/JoinFragment.cs | 3 ++- 20 files changed, 168 insertions(+), 15 deletions(-) diff --git a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs index c573a7666a8..dcca7fc2924 100644 --- a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs @@ -274,6 +274,27 @@ public async Task EntityJoinWithFetchesAsync() } } + [Test] + public async Task CrossJoinAndWhereClauseAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var result = await (session.CreateQuery( + "SELECT s " + + "FROM EntityComplex s cross join EntityComplex q " + + "where s.SameTypeChild.Id = q.SameTypeChild.Id" + ).ListAsync()); + + Assert.That(result, Has.Count.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + if (Dialect.SupportsCrossJoin) + { + Assert.That(sqlLog.GetWholeLog(), Does.Contain("cross join"), "A cross join is expected in the SQL select"); + } + } + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 6d11edace21..4dc750f4119 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -31,5 +31,31 @@ public async Task MultipleLinqJoinsWithSameProjectionNamesAsync() Assert.That(orders.Count, Is.EqualTo(828)); Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } + + [Test] + public async Task CrossJoinWithPredicateInOnStatementAsync() + { + var result = + await ((from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).Take(10).ToListAsync()); + + Assert.That(result.Count, Is.EqualTo(10)); + } + + [Test] + public async Task CrossJoinWithPredicateInWhereStatementAsync() + { + var result = await ((from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToListAsync()); + + Assert.That(result.Count, Is.EqualTo(720)); + } } } diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index 0b8d0c11a6b..0a578bd982a 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -264,7 +264,7 @@ public void EntityJoinWithFetches() } [Test, Ignore("Failing for unrelated reasons")] - public void CrossJoinAndWithClause() + public void ImplicitJoinAndWithClause() { //This is about complex theta style join fix that was implemented in hibernate along with entity join functionality //https://hibernate.atlassian.net/browse/HHH-7321 @@ -279,6 +279,27 @@ public void CrossJoinAndWithClause() } } + [Test] + public void CrossJoinAndWhereClause() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var result = session.CreateQuery( + "SELECT s " + + "FROM EntityComplex s cross join EntityComplex q " + + "where s.SameTypeChild.Id = q.SameTypeChild.Id" + ).List(); + + Assert.That(result, Has.Count.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + if (Dialect.SupportsCrossJoin) + { + Assert.That(sqlLog.GetWholeLog(), Does.Contain("cross join"), "A cross join is expected in the SQL select"); + } + } + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 1d78b46b181..d57bb4cdedf 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -19,5 +19,31 @@ public void MultipleLinqJoinsWithSameProjectionNames() Assert.That(orders.Count, Is.EqualTo(828)); Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } + + [Test] + public void CrossJoinWithPredicateInOnStatement() + { + var result = + (from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).Take(10).ToList(); + + Assert.That(result.Count, Is.EqualTo(10)); + } + + [Test] + public void CrossJoinWithPredicateInWhereStatement() + { + var result = (from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToList(); + + Assert.That(result.Count, Is.EqualTo(720)); + } } } diff --git a/src/NHibernate/AdoNet/Util/BasicFormatter.cs b/src/NHibernate/AdoNet/Util/BasicFormatter.cs index 2e601990e06..c3f35fbddf1 100644 --- a/src/NHibernate/AdoNet/Util/BasicFormatter.cs +++ b/src/NHibernate/AdoNet/Util/BasicFormatter.cs @@ -20,6 +20,7 @@ static BasicFormatter() { beginClauses.Add("left"); beginClauses.Add("right"); + beginClauses.Add("cross"); beginClauses.Add("inner"); beginClauses.Add("outer"); beginClauses.Add("group"); diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs index d56ab6dc06f..3eef1635595 100644 --- a/src/NHibernate/Dialect/DB2Dialect.cs +++ b/src/NHibernate/Dialect/DB2Dialect.cs @@ -300,6 +300,9 @@ public override string ForUpdateString public override bool SupportsResultSetPositionQueryMethodsOnForwardOnlyCursor => false; + /// + public override bool SupportsCrossJoin => false; // DB2 v9.1 doesn't support 'cross join' syntax + public override bool SupportsLobValueChangePropogation => false; public override bool SupportsExistsInSelect => false; diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index e3d0b6cacd9..423c5082361 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -1337,6 +1337,11 @@ public virtual JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } + /// + /// Does this dialect support CROSS JOIN? + /// + public virtual bool SupportsCrossJoin => true; + /// /// Create a strategy responsible /// for handling this dialect's variations in how CASE statements are diff --git a/src/NHibernate/Dialect/InformixDialect0940.cs b/src/NHibernate/Dialect/InformixDialect0940.cs index 13563df43e8..bdfcae4f27f 100644 --- a/src/NHibernate/Dialect/InformixDialect0940.cs +++ b/src/NHibernate/Dialect/InformixDialect0940.cs @@ -126,7 +126,10 @@ public override JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } - /// + /// + public override bool SupportsCrossJoin => false; + + /// /// Does this Dialect have some kind of LIMIT syntax? /// /// False, unless overridden. diff --git a/src/NHibernate/Dialect/Oracle10gDialect.cs b/src/NHibernate/Dialect/Oracle10gDialect.cs index 7219390c0e9..caab3e1f492 100644 --- a/src/NHibernate/Dialect/Oracle10gDialect.cs +++ b/src/NHibernate/Dialect/Oracle10gDialect.cs @@ -15,5 +15,8 @@ public override JoinFragment CreateOuterJoinFragment() { return new ANSIJoinFragment(); } + + /// + public override bool SupportsCrossJoin => true; } } \ No newline at end of file diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index f3a943a3b8e..b2103b87965 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -328,6 +328,9 @@ public override JoinFragment CreateOuterJoinFragment() return new OracleJoinFragment(); } + /// + public override bool SupportsCrossJoin => false; + /// /// Map case support to the Oracle DECODE function. Oracle did not /// add support for CASE until 9i. diff --git a/src/NHibernate/Dialect/SybaseASA9Dialect.cs b/src/NHibernate/Dialect/SybaseASA9Dialect.cs index 12ad1e3c088..f839871b6a1 100644 --- a/src/NHibernate/Dialect/SybaseASA9Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASA9Dialect.cs @@ -97,6 +97,9 @@ public override bool OffsetStartsAtOne get { return true; } } + /// + public override bool SupportsCrossJoin => false; + public override SqlString GetLimitString(SqlString queryString, SqlString offset, SqlString limit) { int intSelectInsertPoint = GetAfterSelectInsertPoint(queryString); diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index 1b33b672cdf..0a84a822424 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -247,7 +247,10 @@ public override bool SupportsExpectedLobUsagePattern { get { return false; } } - + + /// + public override bool SupportsCrossJoin => false; + public override char OpenQuote { get { return '['; } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Hql.g b/src/NHibernate/Hql/Ast/ANTLR/Hql.g index 9f67aaa2162..cef7a744c3f 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Hql.g +++ b/src/NHibernate/Hql/Ast/ANTLR/Hql.g @@ -19,6 +19,7 @@ tokens BETWEEN='between'; CLASS='class'; COUNT='count'; + CROSS='cross'; DELETE='delete'; DESCENDING='desc'; DOT; @@ -255,6 +256,7 @@ fromClause fromJoin : ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? path (asAlias)? (propertyFetch)? (withClause)? | ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? ELEMENTS! OPEN! path CLOSE! (asAlias)? (propertyFetch)? (withClause)? + | CROSS JOIN^ { WeakKeywords(); } path (asAlias)? ; withClause diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g index 22e5450dce1..fba1010337c 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g @@ -306,6 +306,9 @@ joinType returns [int j] | INNER { $j = INNER; } + | CROSS { + $j = CROSS; + } ; // Matches a path and returns the normalized string for the path (usually diff --git a/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs b/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs index c4f850f7aa3..8fbcfba7fd7 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs @@ -48,6 +48,8 @@ public static JoinType ToHibernateJoinType(int astJoinType) return JoinType.RightOuterJoin; case HqlSqlWalker.FULL: return JoinType.FullJoin; + case HqlSqlWalker.CROSS: + return JoinType.CrossJoin; default: throw new AssertionFailure("undefined join type " + astJoinType); } diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index 49cad7607e1..58d54af8804 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -483,6 +483,11 @@ public HqlLeftJoin LeftJoin(HqlExpression expression, HqlAlias @alias) return new HqlLeftJoin(_factory, expression, @alias); } + public HqlCrossJoin CrossJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlCrossJoin(_factory, expression, @alias); + } + public HqlFetchJoin FetchJoin(HqlExpression expression, HqlAlias @alias) { return new HqlFetchJoin(_factory, expression, @alias); diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 5964b99db90..245b82db375 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -851,6 +851,13 @@ public class HqlLeftJoin : HqlTreeNode } } + public class HqlCrossJoin : HqlTreeNode + { + public HqlCrossJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlCross(factory), expression, @alias) + { + } + } + public class HqlFetchJoin : HqlTreeNode { public HqlFetchJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) @@ -906,6 +913,14 @@ public HqlLeft(IASTFactory factory) } } + public class HqlCross : HqlTreeNode + { + public HqlCross(IASTFactory factory) + : base(HqlSqlWalker.CROSS, "cross", factory) + { + } + } + public class HqlAny : HqlBooleanExpression { public HqlAny(IASTFactory factory) : base(HqlSqlWalker.ANY, "any", factory) diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index f043d30c51f..d2f9f892c54 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -315,22 +315,20 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - + var fromExpressionTree = HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters); + var alias = _hqlTree.TreeBuilder.Alias(querySourceName); if (fromClause.FromExpression is MemberExpression) { // It's a join - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Join( - HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(_hqlTree.TreeBuilder.Join(fromExpressionTree.AsExpression(), alias)); } else { - // TODO - exact same code as in MainFromClause; refactor this out - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); + var join = VisitorParameters.SessionFactory.Dialect.SupportsCrossJoin + ? _hqlTree.TreeBuilder.CrossJoin(fromExpressionTree.AsExpression(), alias) + : (HqlTreeNode) _hqlTree.TreeBuilder.Range(fromExpressionTree, alias); + + _hqlTree.AddFromClause(join); } base.VisitAdditionalFromClause(fromClause, queryModel, index); diff --git a/src/NHibernate/SqlCommand/ANSIJoinFragment.cs b/src/NHibernate/SqlCommand/ANSIJoinFragment.cs index aa9421fa08e..063c361eb56 100644 --- a/src/NHibernate/SqlCommand/ANSIJoinFragment.cs +++ b/src/NHibernate/SqlCommand/ANSIJoinFragment.cs @@ -33,12 +33,21 @@ public override void AddJoin(string tableName, string alias, string[] fkColumns, case JoinType.FullJoin: joinString = " full outer join "; break; + case JoinType.CrossJoin: + joinString = " cross join "; + break; default: throw new AssertionFailure("undefined join type"); } - _fromFragment.Add(joinString + tableName + ' ' + alias + " on "); + _fromFragment.Add(joinString).Add(tableName).Add(" ").Add(alias).Add(" "); + if (joinType == JoinType.CrossJoin) + { + // Cross join does not have an 'on' statement + return; + } + _fromFragment.Add("on "); if (fkColumns.Length == 0) { AddBareCondition(_fromFragment, on); diff --git a/src/NHibernate/SqlCommand/JoinFragment.cs b/src/NHibernate/SqlCommand/JoinFragment.cs index 18b249bd57f..38c5d8ce280 100644 --- a/src/NHibernate/SqlCommand/JoinFragment.cs +++ b/src/NHibernate/SqlCommand/JoinFragment.cs @@ -10,7 +10,8 @@ public enum JoinType InnerJoin = 0, FullJoin = 4, LeftOuterJoin = 1, - RightOuterJoin = 2 + RightOuterJoin = 2, + CrossJoin = 8 } /// From 2eab4db5f876d7d6a718c118c60b01980064646c Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 12 Mar 2020 21:09:52 +0100 Subject: [PATCH 02/10] Code review changes --- .../Async/Linq/ByMethod/JoinTests.cs | 49 ++++--- .../Async/Linq/LinqQuerySamples.cs | 125 +++++++++++++--- .../Linq/ByMethod/JoinTests.cs | 49 ++++--- src/NHibernate.Test/Linq/LinqQuerySamples.cs | 133 +++++++++++++++--- src/NHibernate.Test/TestCase.cs | 69 +++++++++ src/NHibernate/Hql/Ast/HqlTreeBuilder.cs | 5 + src/NHibernate/Hql/Ast/HqlTreeNode.cs | 16 +++ src/NHibernate/Linq/Clauses/NhJoinClause.cs | 2 + .../Linq/ReWriters/AddJoinsReWriter.cs | 19 +++ .../Linq/Visitors/QueryModelVisitor.cs | 52 +++++-- 10 files changed, 421 insertions(+), 98 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 4dc750f4119..86802cdb345 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -8,7 +8,13 @@ //------------------------------------------------------------------------------ +using System; using System.Linq; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; using NUnit.Framework; using NHibernate.Linq; @@ -32,30 +38,31 @@ public async Task MultipleLinqJoinsWithSameProjectionNamesAsync() Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } - [Test] - public async Task CrossJoinWithPredicateInOnStatementAsync() + [TestCase(false)] + [TestCase(true)] + public async Task CrossJoinWithPredicateInWhereStatementAsync(bool useCrossJoin) { - var result = - await ((from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } - into details - from d in details - select new { o.OrderId, p.ProductId, d.UnitPrice }).Take(10).ToListAsync()); - - Assert.That(result.Count, Is.EqualTo(10)); - } + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } - [Test] - public async Task CrossJoinWithPredicateInWhereStatementAsync() - { - var result = await ((from o in db.Orders - from o2 in db.Orders.Where(x => x.Freight > 50) - where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) - select new { o.OrderId, OrderId2 = o2.OrderId }).ToListAsync()); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var result = await ((from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToListAsync()); - Assert.That(result.Count, Is.EqualTo(720)); + var sql = sqlSpy.GetWholeLog(); + Assert.That(result.Count, Is.EqualTo(720)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 0 : 1)); + } } } } diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index 44fd7ef3cd8..b61f29330d5 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -9,9 +9,11 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NSubstitute; using NUnit.Framework; using NHibernate.Linq; @@ -757,7 +759,13 @@ from o in c.Orders where c.Address.City == "London" select o; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -863,7 +871,13 @@ from p in db.Products where p.Supplier.Address.Country == "USA" && p.UnitsInStock == 0 select p; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -879,7 +893,16 @@ from et in e.Territories where e.Address.City == "Seattle" select new {e.FirstName, e.LastName, et.Region.Description}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + // EmployeeTerritories and Territories + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + // Region + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -903,7 +926,13 @@ from e2 in e1.Subordinates e1.Address.City }; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -918,7 +947,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders select new {c.ContactName, OrderCount = orders.Average(x => x.Freight)}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -959,15 +994,32 @@ from c in db.Customers } [Category("JOIN")] - [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] - public async Task DLinqJoin5dAsync() + [TestCase(true, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + [TestCase(false, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5dAsync(bool useCrossJoin) { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var q = from c in db.Customers join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; - await (ObjectDumper.WriteAsync(q)); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -983,7 +1035,13 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords join e in db.Employees on c.Address.City equals e.Address.City into emps select new {c.ContactName, ords = ords.Count(), emps = emps.Count()}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -997,14 +1055,27 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords from o in ords select new {c.ContactName, o.OrderId, z}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] - [Test(Description = "This sample shows a group join with a composite key.")] - public async Task DLinqJoin9Async() + [TestCase(true, Description = "This sample shows a group join with a composite key.")] + [TestCase(false, Description = "This sample shows a group join with a composite key.")] + public async Task DLinqJoin9Async(bool useCrossJoin) { - var expected = + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + ICollection expected, actual; + expected = (from o in db.Orders.ToList() from p in db.Products.ToList() join d in db.OrderLines.ToList() @@ -1013,14 +1084,26 @@ into details from d in details select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); - var actual = - await ((from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToListAsync()); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + actual = + await ((from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(2155)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + } Assert.AreEqual(expected.Count, actual.Count); } diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index d57bb4cdedf..a013014cae8 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -1,4 +1,10 @@ -using System.Linq; +using System; +using System.Linq; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; using NUnit.Framework; namespace NHibernate.Test.Linq.ByMethod @@ -20,30 +26,31 @@ public void MultipleLinqJoinsWithSameProjectionNames() Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } - [Test] - public void CrossJoinWithPredicateInOnStatement() + [TestCase(false)] + [TestCase(true)] + public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin) { - var result = - (from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } - into details - from d in details - select new { o.OrderId, p.ProductId, d.UnitPrice }).Take(10).ToList(); + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } - Assert.That(result.Count, Is.EqualTo(10)); - } + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); - [Test] - public void CrossJoinWithPredicateInWhereStatement() - { - var result = (from o in db.Orders - from o2 in db.Orders.Where(x => x.Freight > 50) - where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) - select new { o.OrderId, OrderId2 = o2.OrderId }).ToList(); + var result = (from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToList(); - Assert.That(result.Count, Is.EqualTo(720)); + var sql = sqlSpy.GetWholeLog(); + Assert.That(result.Count, Is.EqualTo(720)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 0 : 1)); + } } } } diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index efe18a3caad..3dd7c1d3080 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1,7 +1,9 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NSubstitute; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -1301,7 +1303,13 @@ from o in c.Orders where c.Address.City == "London" select o; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1407,7 +1415,13 @@ from p in db.Products where p.Supplier.Address.Country == "USA" && p.UnitsInStock == 0 select p; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1423,7 +1437,16 @@ from et in e.Territories where e.Address.City == "Seattle" select new {e.FirstName, e.LastName, et.Region.Description}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + // EmployeeTerritories and Territories + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + // Region + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1447,7 +1470,13 @@ from e2 in e1.Subordinates e1.Address.City }; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1462,7 +1491,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders select new {c.ContactName, OrderCount = orders.Average(x => x.Freight)}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -1503,15 +1538,32 @@ from c in db.Customers } [Category("JOIN")] - [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] - public void DLinqJoin5d() + [TestCase(true, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + [TestCase(false, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5d(bool useCrossJoin) { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var q = from c in db.Customers join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; - ObjectDumper.Write(q); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1527,7 +1579,13 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords join e in db.Employees on c.Address.City equals e.Address.City into emps select new {c.ContactName, ords = ords.Count(), emps = emps.Count()}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -1544,7 +1602,13 @@ join o in db.Orders on e equals o.Employee into ords from o in ords.DefaultIfEmpty() select new {e.FirstName, e.LastName, Order = o}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1558,14 +1622,27 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords from o in ords select new {c.ContactName, o.OrderId, z}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] - [Test(Description = "This sample shows a group join with a composite key.")] - public void DLinqJoin9() + [TestCase(true, Description = "This sample shows a group join with a composite key.")] + [TestCase(false, Description = "This sample shows a group join with a composite key.")] + public void DLinqJoin9(bool useCrossJoin) { - var expected = + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + ICollection expected, actual; + expected = (from o in db.Orders.ToList() from p in db.Products.ToList() join d in db.OrderLines.ToList() @@ -1574,14 +1651,26 @@ into details from d in details select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); - var actual = - (from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + actual = + (from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(2155)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + } Assert.AreEqual(expected.Count, actual.Count); } diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 4c9ea610c6d..486720300c5 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -16,6 +16,9 @@ using System.Text; using NHibernate.Dialect; using NHibernate.Driver; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; namespace NHibernate.Test { @@ -477,6 +480,72 @@ protected void AssumeFunctionSupported(string functionName) $"{Dialect} doesn't support {functionName} standard function."); } + protected void ClearQueryPlanCache() + { + var planCacheField = typeof(QueryPlanCache) + .GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new InvalidOperationException("planCache field does not exist in QueryPlanCache."); + + var planCache = (SoftLimitMRUCache) planCacheField.GetValue(Sfi.QueryPlanCache); + planCache.Clear(); + } + + protected Substitute SubstituteDialect() + { + var origDialect = Sfi.Settings.Dialect; + var dialectProperty = (PropertyInfo) ReflectHelper.GetProperty(o => o.Dialect); + var forPartsOfMethod = ReflectHelper.GetMethodDefinition(() => Substitute.ForPartsOf()); + var substitute = (Dialect.Dialect) forPartsOfMethod.MakeGenericMethod(origDialect.GetType()) + .Invoke(null, new object[] { new object[0] }); + + dialectProperty.SetValue(Sfi.Settings, substitute); + + return new Substitute(substitute, Dispose); + + void Dispose() + { + dialectProperty.SetValue(Sfi.Settings, origDialect); + } + } + + protected static int GetTotalOccurrences(string content, string substring) + { + if (string.IsNullOrEmpty(substring)) + { + throw new ArgumentNullException(nameof(substring)); + } + + int occurrences = 0; + for (var index = 0; ; index += substring.Length) + { + index = content.IndexOf(substring, index); + if (index == -1) + { + return occurrences; + } + + occurrences++; + } + } + + protected struct Substitute : IDisposable + { + private readonly System.Action _disposeAction; + + public Substitute(TType value, System.Action disposeAction) + { + Value = value; + _disposeAction = disposeAction; + } + + public TType Value { get; } + + public void Dispose() + { + _disposeAction(); + } + } + #endregion } } diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index 58d54af8804..bb208295afc 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -478,6 +478,11 @@ public HqlIn In(HqlExpression itemExpression, HqlTreeNode source) return new HqlIn(_factory, itemExpression, source); } + public HqlInnerJoin InnerJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlInnerJoin(_factory, expression, @alias); + } + public HqlLeftJoin LeftJoin(HqlExpression expression, HqlAlias @alias) { return new HqlLeftJoin(_factory, expression, @alias); diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 245b82db375..160739d7f16 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -844,6 +844,14 @@ public HqlJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : } } + public class HqlInnerJoin : HqlTreeNode + { + public HqlInnerJoin(IASTFactory factory, HqlExpression expression, HqlAlias alias) + : base(HqlSqlWalker.JOIN, "join", factory, new HqlInner(factory), expression, alias) + { + } + } + public class HqlLeftJoin : HqlTreeNode { public HqlLeftJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlLeft(factory), expression, @alias) @@ -905,6 +913,14 @@ public HqlBitwiseAnd(IASTFactory factory, HqlExpression lhs, HqlExpression rhs) } } + public class HqlInner : HqlTreeNode + { + public HqlInner(IASTFactory factory) + : base(HqlSqlWalker.INNER, "inner", factory) + { + } + } + public class HqlLeft : HqlTreeNode { public HqlLeft(IASTFactory factory) diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs index d68926f88b3..0df0afd2b32 100644 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ b/src/NHibernate/Linq/Clauses/NhJoinClause.cs @@ -54,6 +54,8 @@ public NhJoinClause(string itemName, System.Type itemType, Expression fromExpres public bool IsInner { get; private set; } + internal IBodyClause RelatedBodyClause { get; set; } + public void TransformExpressions(Func transformation) { if (transformation == null) throw new ArgumentNullException(nameof(transformation)); diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 576bb65a9ea..0a58fc85770 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Specialized; using System.Linq; using NHibernate.Engine; using NHibernate.Linq.Clauses; @@ -59,6 +60,24 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel _whereJoinDetector.Transform(havingClause); } + public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) + { + // When there are association navigations inside an on clause (e.g. c.ContactTitle equals o.Customer.ContactTitle), + // we have to move the condition to the where statement, otherwise the query will be invalid. + // Link newly created joins with the current join clause in order to later detect which join type to use. + queryModel.BodyClauses.CollectionChanged += OnCollectionChange; + _whereJoinDetector.Transform(joinClause); + queryModel.BodyClauses.CollectionChanged -= OnCollectionChange; + + void OnCollectionChange(object sender, NotifyCollectionChangedEventArgs e) + { + foreach (var nhJoinClause in e.NewItems.OfType()) + { + nhJoinClause.RelatedBodyClause = joinClause; + } + } + } + public bool IsEntity(System.Type type) { return _sessionFactory.GetImplementors(type.FullName).Any(); diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index d2f9f892c54..04a1e4afa38 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Hql.Ast; @@ -17,6 +18,8 @@ using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; using Remotion.Linq.EagerFetching; +using OrderByClause = Remotion.Linq.Clauses.OrderByClause; +using SelectClause = Remotion.Linq.Clauses.SelectClause; namespace NHibernate.Linq.Visitors { @@ -324,11 +327,7 @@ public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, } else { - var join = VisitorParameters.SessionFactory.Dialect.SupportsCrossJoin - ? _hqlTree.TreeBuilder.CrossJoin(fromExpressionTree.AsExpression(), alias) - : (HqlTreeNode) _hqlTree.TreeBuilder.Range(fromExpressionTree, alias); - - _hqlTree.AddFromClause(join); + _hqlTree.AddFromClause(CreateCrossJoin(fromExpressionTree, alias)); } base.VisitAdditionalFromClause(fromClause, queryModel, index); @@ -515,15 +514,25 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) { var equalityVisitor = new EqualityHqlGenerator(VisitorParameters); - var whereClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(joinClause); - - _hqlTree.AddWhereClause(whereClause); + var withClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); + var alias = _hqlTree.TreeBuilder.Alias(VisitorParameters.QuerySourceNamer.GetName(joinClause)); + var joinExpression = HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters); + HqlTreeNode join; + // When there are association navigations inside an on clause: + // from c in db.Customers join o in db.Orders on c.ContactTitle equals o.Customer.ContactTitle + // we have to use a cross join instead of inner join and add the condition in the where statement. + if (queryModel.BodyClauses.OfType().Any(o => o.RelatedBodyClause == joinClause)) + { + _hqlTree.AddWhereClause(withClause); + join = CreateCrossJoin(joinExpression, alias); + } + else + { + join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + join.AddChild(_hqlTree.TreeBuilder.With(withClause)); + } - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(join); } public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, QueryModel queryModel, int index) @@ -550,5 +559,22 @@ public override void VisitNhWithClause(NhWithClause withClause, QueryModel query var expression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); _hqlTree.AddWhereClause(expression); } + + private HqlTreeNode CreateCrossJoin(HqlTreeNode joinExpression, HqlAlias alias) + { + if (VisitorParameters.SessionFactory.Dialect.SupportsCrossJoin) + { + return _hqlTree.TreeBuilder.CrossJoin(joinExpression.AsExpression(), alias); + } + + // Simulate cross join as a inner join on 1=1 + var join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + var onExpression = _hqlTree.TreeBuilder.Equality( + _hqlTree.TreeBuilder.True(), + _hqlTree.TreeBuilder.True()); + join.AddChild(_hqlTree.TreeBuilder.With(onExpression)); + + return join; + } } } From 000137a3884e468d77724de18b30a7d6babb4974 Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 12 Mar 2020 21:29:50 +0100 Subject: [PATCH 03/10] Fix CodeFactor issue --- src/NHibernate.Test/TestCase.cs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 486720300c5..0d59b989cf4 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -515,17 +515,14 @@ protected static int GetTotalOccurrences(string content, string substring) throw new ArgumentNullException(nameof(substring)); } - int occurrences = 0; - for (var index = 0; ; index += substring.Length) + int occurrences = 0, index = 0; + while ((index = content.IndexOf(substring, index)) >= 0) { - index = content.IndexOf(substring, index); - if (index == -1) - { - return occurrences; - } - occurrences++; + index += substring.Length; } + + return occurrences; } protected struct Substitute : IDisposable From b17820aaaff34df722fd1820e7a3e31193ea0a70 Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 12 Mar 2020 22:27:00 +0100 Subject: [PATCH 04/10] Add test for #1128 --- .../Async/Linq/LinqQuerySamples.cs | 19 +++++++++++++++++++ src/NHibernate.Test/Linq/LinqQuerySamples.cs | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index b61f29330d5..80d443967e9 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -1022,6 +1022,25 @@ from c in db.Customers } } + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public async Task DLinqJoin5eAsync() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins three tables and projects results from each of them.")] public async Task DLinqJoin6Async() diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index 3dd7c1d3080..234ca67ee53 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1566,6 +1566,25 @@ from c in db.Customers } } + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public void DLinqJoin5e() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins three tables and projects results from each of them.")] public void DLinqJoin6() From b57e654a002b5147dd4b97c97ec769e0e8f933ca Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 15 Mar 2020 00:43:50 +0100 Subject: [PATCH 05/10] Avoid cross join for associations inside an outer key selector --- .../Async/Linq/LinqQuerySamples.cs | 52 ++++++++++++++++++- src/NHibernate.Test/Linq/LinqQuerySamples.cs | 52 ++++++++++++++++++- src/NHibernate/Linq/Clauses/NhJoinClause.cs | 2 +- .../Linq/ReWriters/AddJoinsReWriter.cs | 47 ++++++++++++----- src/NHibernate/Linq/Visitors/JoinBuilder.cs | 18 ++++++- .../Linq/Visitors/QueryModelVisitor.cs | 7 ++- .../Linq/Visitors/WhereJoinDetector.cs | 11 ++++ 7 files changed, 165 insertions(+), 24 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index 80d443967e9..3058db20e83 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -965,7 +965,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId select new { c.ContactName, o.OrderId }; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1005,7 +1011,9 @@ public async Task DLinqJoin5dAsync(bool useCrossJoin) var q = from c in db.Customers - join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } + join o in db.Orders on + new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals + new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; using (var substitute = SubstituteDialect()) @@ -1041,6 +1049,27 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId } } + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5fAsync() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins three tables and projects results from each of them.")] public async Task DLinqJoin6Async() @@ -1138,5 +1167,24 @@ group o by c into x await (ObjectDumper.WriteAsync(q)); } + + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables.")] + public async Task DLinqJoin10aAsync() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } } } diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index 234ca67ee53..848b0057ce1 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1509,7 +1509,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId select new { c.ContactName, o.OrderId }; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1549,7 +1555,9 @@ public void DLinqJoin5d(bool useCrossJoin) var q = from c in db.Customers - join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } + join o in db.Orders on + new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals + new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; using (var substitute = SubstituteDialect()) @@ -1585,6 +1593,27 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId } } + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5f() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins three tables and projects results from each of them.")] public void DLinqJoin6() @@ -1706,6 +1735,25 @@ group o by c into x ObjectDumper.Write(q); } + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables.")] + public void DLinqJoin10a() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } + [Category("WHERE")] [Test(Description = "This sample uses WHERE to filter for orders with shipping date equals to null.")] public void DLinq2B() diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs index 0df0afd2b32..3a1c17147aa 100644 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ b/src/NHibernate/Linq/Clauses/NhJoinClause.cs @@ -54,7 +54,7 @@ public NhJoinClause(string itemName, System.Type itemType, Expression fromExpres public bool IsInner { get; private set; } - internal IBodyClause RelatedBodyClause { get; set; } + internal JoinClause ParentJoinClause { get; set; } public void TransformExpressions(Func transformation) { diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 0a58fc85770..44e40652e89 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -20,11 +20,13 @@ public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider private readonly ISessionFactoryImplementor _sessionFactory; private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; + private int? _joinInsertIndex; + private JoinClause _currentJoin; private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) { _sessionFactory = sessionFactory; - var joiner = new Joiner(queryModel); + var joiner = new Joiner(queryModel, AddJoin); _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); _whereJoinDetector = new WhereJoinDetector(this, joiner); } @@ -62,20 +64,25 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) { - // When there are association navigations inside an on clause (e.g. c.ContactTitle equals o.Customer.ContactTitle), + VisitJoinClause(joinClause, queryModel, joinClause); + } + + private void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, IBodyClause bodyClause) + { + joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence); + + // When associations are located in the outer key (e.g. from a in A join b in B b on a.C.D.Id equals b.Id), + // we have to insert the association join before the current join in order to produce a valid query. + _joinInsertIndex = queryModel.BodyClauses.IndexOf(bodyClause); + joinClause.OuterKeySelector = _whereJoinDetector.Transform(joinClause.OuterKeySelector); + _joinInsertIndex = null; + + // When associations are located in the inner key (e.g. from a in A join b in B b on a.Id equals b.C.D.Id), // we have to move the condition to the where statement, otherwise the query will be invalid. // Link newly created joins with the current join clause in order to later detect which join type to use. - queryModel.BodyClauses.CollectionChanged += OnCollectionChange; - _whereJoinDetector.Transform(joinClause); - queryModel.BodyClauses.CollectionChanged -= OnCollectionChange; - - void OnCollectionChange(object sender, NotifyCollectionChangedEventArgs e) - { - foreach (var nhJoinClause in e.NewItems.OfType()) - { - nhJoinClause.RelatedBodyClause = joinClause; - } - } + _currentJoin = joinClause; + joinClause.InnerKeySelector = _whereJoinDetector.Transform(joinClause.InnerKeySelector); + _currentJoin = null; } public bool IsEntity(System.Type type) @@ -88,5 +95,19 @@ public bool IsIdentifier(System.Type type, string propertyName) var metadata = _sessionFactory.GetClassMetadata(type); return metadata != null && propertyName.Equals(metadata.IdentifierPropertyName); } + + private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) + { + joinClause.ParentJoinClause = _currentJoin; + if (_joinInsertIndex.HasValue) + { + queryModel.BodyClauses.Insert(_joinInsertIndex.Value, joinClause); + _joinInsertIndex++; + } + else + { + queryModel.BodyClauses.Add(joinClause); + } + } } } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index b84a72c7ac3..dc7bc395c13 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -21,11 +21,20 @@ public class Joiner : IJoiner private readonly NameGenerator _nameGenerator; private readonly QueryModel _queryModel; + internal Joiner(QueryModel queryModel, System.Action addJoinMethod) + : this(queryModel) + { + AddJoinMethod = addJoinMethod; + } + internal Joiner(QueryModel queryModel) { _nameGenerator = new NameGenerator(queryModel); _queryModel = queryModel; + AddJoinMethod = AddJoin; } + + internal System.Action AddJoinMethod { get; } public IEnumerable Joins { @@ -39,7 +48,7 @@ public Expression AddJoin(Expression expression, string key) if (!_joins.TryGetValue(key, out join)) { join = new NhJoinClause(_nameGenerator.GetNewName(), expression.Type, expression); - _queryModel.BodyClauses.Add(join); + AddJoinMethod(_queryModel, join); _joins.Add(key, join); } @@ -72,6 +81,11 @@ public bool CanAddJoin(Expression expression) return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase); } + private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) + { + queryModel.BodyClauses.Add(joinClause); + } + private class QuerySourceExtractor : RelinqExpressionVisitor { private IQuerySource _querySource; @@ -90,4 +104,4 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index 04a1e4afa38..a487a1281ce 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -518,10 +518,9 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode var alias = _hqlTree.TreeBuilder.Alias(VisitorParameters.QuerySourceNamer.GetName(joinClause)); var joinExpression = HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters); HqlTreeNode join; - // When there are association navigations inside an on clause: - // from c in db.Customers join o in db.Orders on c.ContactTitle equals o.Customer.ContactTitle - // we have to use a cross join instead of inner join and add the condition in the where statement. - if (queryModel.BodyClauses.OfType().Any(o => o.RelatedBodyClause == joinClause)) + // When associations are located inside the inner key selector we have to use a cross join instead of an inner + // join and add the condition in the where statement. + if (queryModel.BodyClauses.OfType().Any(o => o.ParentJoinClause == joinClause)) { _hqlTree.AddWhereClause(withClause); join = CreateCrossJoin(joinExpression, alias); diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 66508f01eef..68f88c6cc55 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -77,10 +77,21 @@ internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) _joiner = joiner; } + public Expression Transform(Expression expression) + { + var result = Visit(expression); + PostTransform(); + return result; + } + public void Transform(IClause whereClause) { whereClause.TransformExpressions(Visit); + PostTransform(); + } + private void PostTransform() + { var values = _values.Pop(); foreach (var memberExpression in values.MemberExpressions) From 4cee21715c26e8427a911f416e7c787041df7a00 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 15 Mar 2020 01:33:39 +0100 Subject: [PATCH 06/10] Add property fetch support for cross join --- .../FetchLazyPropertiesFixture.cs | 33 +++++++++++++++++++ .../FetchLazyPropertiesFixture.cs | 33 +++++++++++++++++++ src/NHibernate/Hql/Ast/ANTLR/Hql.g | 2 +- 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs index a971ec6d132..9cff3809b59 100644 --- a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -9,12 +9,14 @@ using System; +using System.Collections.Generic; using System.Linq; using NHibernate.Cache; using NHibernate.Cfg; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq; using NUnit.Framework; +using NUnit.Framework.Constraints; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Test.FetchLazyProperties @@ -943,6 +945,37 @@ public async Task TestFetchAfterEntityIsInitializedAsync(bool readOnly) Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), Is.True); } + [Test] + public async Task TestHqlCrossJoinFetchFormulaAsync() + { + var persons = new List(); + var bestFriends = new List(); + using (var sqlSpy = new SqlLogSpy()) + using (var s = OpenSession()) + { + var list = await (s.CreateQuery("select p, bf from Person p cross join Person bf fetch bf.Formula where bf.Id = p.BestFriend.Id").ListAsync()); + foreach (var arr in list) + { + persons.Add((Person) arr[0]); + bestFriends.Add((Person) arr[1]); + } + } + + AssertPersons(persons, false); + AssertPersons(bestFriends, true); + + void AssertPersons(List results, bool fetched) + { + foreach (var person in results) + { + Assert.That(person, Is.Not.Null); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Image"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Address"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), fetched ? Is.True : (IResolveConstraint) Is.False); + } + } + } + private static Person GeneratePerson(int i, Person bestFriend) { return new Person diff --git a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs index 04cb354f3f4..74823eba282 100644 --- a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -1,10 +1,12 @@ using System; +using System.Collections.Generic; using System.Linq; using NHibernate.Cache; using NHibernate.Cfg; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq; using NUnit.Framework; +using NUnit.Framework.Constraints; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Test.FetchLazyProperties @@ -932,6 +934,37 @@ public void TestFetchAfterEntityIsInitialized(bool readOnly) Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), Is.True); } + [Test] + public void TestHqlCrossJoinFetchFormula() + { + var persons = new List(); + var bestFriends = new List(); + using (var sqlSpy = new SqlLogSpy()) + using (var s = OpenSession()) + { + var list = s.CreateQuery("select p, bf from Person p cross join Person bf fetch bf.Formula where bf.Id = p.BestFriend.Id").List(); + foreach (var arr in list) + { + persons.Add((Person) arr[0]); + bestFriends.Add((Person) arr[1]); + } + } + + AssertPersons(persons, false); + AssertPersons(bestFriends, true); + + void AssertPersons(List results, bool fetched) + { + foreach (var person in results) + { + Assert.That(person, Is.Not.Null); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Image"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Address"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), fetched ? Is.True : (IResolveConstraint) Is.False); + } + } + } + private static Person GeneratePerson(int i, Person bestFriend) { return new Person diff --git a/src/NHibernate/Hql/Ast/ANTLR/Hql.g b/src/NHibernate/Hql/Ast/ANTLR/Hql.g index cef7a744c3f..efe660d8918 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Hql.g +++ b/src/NHibernate/Hql/Ast/ANTLR/Hql.g @@ -256,7 +256,7 @@ fromClause fromJoin : ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? path (asAlias)? (propertyFetch)? (withClause)? | ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? ELEMENTS! OPEN! path CLOSE! (asAlias)? (propertyFetch)? (withClause)? - | CROSS JOIN^ { WeakKeywords(); } path (asAlias)? + | CROSS JOIN^ { WeakKeywords(); } path (asAlias)? (propertyFetch)? ; withClause From f44223ef47f4d2361ab05bc75fed8bf50064cc81 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 15 Mar 2020 18:00:18 +0100 Subject: [PATCH 07/10] Code review changes --- .../FetchLazyPropertiesFixture.cs | 5 +++++ .../Async/Linq/LinqQuerySamples.cs | 4 ++-- .../FetchLazyPropertiesFixture.cs | 5 +++++ src/NHibernate.Test/Linq/LinqQuerySamples.cs | 4 ++-- .../Linq/ReWriters/AddJoinsReWriter.cs | 19 ++++--------------- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs index 9cff3809b59..8fe1e68ba81 100644 --- a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -948,6 +948,11 @@ public async Task TestFetchAfterEntityIsInitializedAsync(bool readOnly) [Test] public async Task TestHqlCrossJoinFetchFormulaAsync() { + if (!Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var persons = new List(); var bestFriends = new List(); using (var sqlSpy = new SqlLogSpy()) diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index 3058db20e83..f99abd00850 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -1065,8 +1065,8 @@ join c in db.Customers on await (ObjectDumper.WriteAsync(q)); var sql = sqlSpy.GetWholeLog(); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); } } diff --git a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs index 74823eba282..49a75c4b107 100644 --- a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -937,6 +937,11 @@ public void TestFetchAfterEntityIsInitialized(bool readOnly) [Test] public void TestHqlCrossJoinFetchFormula() { + if (!Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var persons = new List(); var bestFriends = new List(); using (var sqlSpy = new SqlLogSpy()) diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index 848b0057ce1..2277e4e3d61 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1609,8 +1609,8 @@ join c in db.Customers on ObjectDumper.Write(q); var sql = sqlSpy.GetWholeLog(); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); - Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); } } diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 44e40652e89..4252890614b 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -20,7 +20,6 @@ public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider private readonly ISessionFactoryImplementor _sessionFactory; private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; - private int? _joinInsertIndex; private JoinClause _currentJoin; private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) @@ -72,13 +71,11 @@ private void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, IBody joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence); // When associations are located in the outer key (e.g. from a in A join b in B b on a.C.D.Id equals b.Id), - // we have to insert the association join before the current join in order to produce a valid query. - _joinInsertIndex = queryModel.BodyClauses.IndexOf(bodyClause); - joinClause.OuterKeySelector = _whereJoinDetector.Transform(joinClause.OuterKeySelector); - _joinInsertIndex = null; + // do nothing and leave them to HQL for adding the missing joins. // When associations are located in the inner key (e.g. from a in A join b in B b on a.Id equals b.C.D.Id), - // we have to move the condition to the where statement, otherwise the query will be invalid. + // we have to move the condition to the where statement, otherwise the query will be invalid (HQL does not + // support them). // Link newly created joins with the current join clause in order to later detect which join type to use. _currentJoin = joinClause; joinClause.InnerKeySelector = _whereJoinDetector.Transform(joinClause.InnerKeySelector); @@ -99,15 +96,7 @@ public bool IsIdentifier(System.Type type, string propertyName) private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) { joinClause.ParentJoinClause = _currentJoin; - if (_joinInsertIndex.HasValue) - { - queryModel.BodyClauses.Insert(_joinInsertIndex.Value, joinClause); - _joinInsertIndex++; - } - else - { - queryModel.BodyClauses.Add(joinClause); - } + queryModel.BodyClauses.Add(joinClause); } } } From 2a334777362f1f382b8ad9edb5f787dd58e16fd3 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 15 Mar 2020 18:53:06 +0100 Subject: [PATCH 08/10] Use inner joins for inner key associations --- src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs | 4 +++- src/NHibernate.Test/Linq/LinqQuerySamples.cs | 4 +++- src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs | 6 ++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index f99abd00850..fbcfa1224dc 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -1026,7 +1026,9 @@ join o in db.Orders on var sql = sqlSpy.GetWholeLog(); Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + Assert.That(GetTotalOccurrences(sql, "cross join"), Is.EqualTo(useCrossJoin ? 1 : 0)); } } diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index 2277e4e3d61..af77418ab5f 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1570,7 +1570,9 @@ join o in db.Orders on var sql = sqlSpy.GetWholeLog(); Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); - Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + Assert.That(GetTotalOccurrences(sql, "cross join"), Is.EqualTo(useCrossJoin ? 1 : 0)); } } diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 4252890614b..9eb8f028d68 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -96,6 +96,12 @@ public bool IsIdentifier(System.Type type, string propertyName) private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) { joinClause.ParentJoinClause = _currentJoin; + if (_currentJoin != null) + { + // Match the parent join type + joinClause.MakeInner(); + } + queryModel.BodyClauses.Add(joinClause); } } From 3c93bfff8a6e72a96c9423fd9bd517cdebde5aa3 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 15 Mar 2020 19:02:58 +0100 Subject: [PATCH 09/10] Remove unneeded private method --- src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 9eb8f028d68..d022e1ffc88 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -62,11 +62,6 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel } public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) - { - VisitJoinClause(joinClause, queryModel, joinClause); - } - - private void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, IBodyClause bodyClause) { joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence); From eab20453e98e1a1a921465b74f9edb29526fe47b Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 17 Mar 2020 21:19:17 +0100 Subject: [PATCH 10/10] Comment Linq to Objects query --- .../Async/Linq/LinqQuerySamples.cs | 22 +++++++++---------- src/NHibernate.Test/Linq/LinqQuerySamples.cs | 22 +++++++++---------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index fbcfa1224dc..a332a24f6ce 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -1124,15 +1124,15 @@ public async Task DLinqJoin9Async(bool useCrossJoin) Assert.Ignore("Dialect does not support cross join."); } - ICollection expected, actual; - expected = - (from o in db.Orders.ToList() - from p in db.Products.ToList() - join d in db.OrderLines.ToList() - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); using (var substitute = SubstituteDialect()) using (var sqlSpy = new SqlLogSpy()) @@ -1140,7 +1140,7 @@ from d in details ClearQueryPlanCache(); substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); - actual = + var actual = await ((from o in db.Orders from p in db.Products join d in db.OrderLines @@ -1154,8 +1154,6 @@ from d in details Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); } - - Assert.AreEqual(expected.Count, actual.Count); } [Category("JOIN")] diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index af77418ab5f..193c313130c 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1691,15 +1691,15 @@ public void DLinqJoin9(bool useCrossJoin) Assert.Ignore("Dialect does not support cross join."); } - ICollection expected, actual; - expected = - (from o in db.Orders.ToList() - from p in db.Products.ToList() - join d in db.OrderLines.ToList() - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); using (var substitute = SubstituteDialect()) using (var sqlSpy = new SqlLogSpy()) @@ -1707,7 +1707,7 @@ from d in details ClearQueryPlanCache(); substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); - actual = + var actual = (from o in db.Orders from p in db.Products join d in db.OrderLines @@ -1721,8 +1721,6 @@ from d in details Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); } - - Assert.AreEqual(expected.Count, actual.Count); } [Category("JOIN")]