diff --git a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs index a971ec6d132..8fe1e68ba81 100644 --- a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -9,12 +9,14 @@ using System; +using System.Collections.Generic; using System.Linq; using NHibernate.Cache; using NHibernate.Cfg; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq; using NUnit.Framework; +using NUnit.Framework.Constraints; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Test.FetchLazyProperties @@ -943,6 +945,42 @@ public async Task TestFetchAfterEntityIsInitializedAsync(bool readOnly) Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), Is.True); } + [Test] + public async Task TestHqlCrossJoinFetchFormulaAsync() + { + if (!Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + var persons = new List(); + var bestFriends = new List(); + using (var sqlSpy = new SqlLogSpy()) + using (var s = OpenSession()) + { + var list = await (s.CreateQuery("select p, bf from Person p cross join Person bf fetch bf.Formula where bf.Id = p.BestFriend.Id").ListAsync()); + foreach (var arr in list) + { + persons.Add((Person) arr[0]); + bestFriends.Add((Person) arr[1]); + } + } + + AssertPersons(persons, false); + AssertPersons(bestFriends, true); + + void AssertPersons(List results, bool fetched) + { + foreach (var person in results) + { + Assert.That(person, Is.Not.Null); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Image"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Address"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), fetched ? Is.True : (IResolveConstraint) Is.False); + } + } + } + private static Person GeneratePerson(int i, Person bestFriend) { return new Person diff --git a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs index c573a7666a8..dcca7fc2924 100644 --- a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs @@ -274,6 +274,27 @@ public async Task EntityJoinWithFetchesAsync() } } + [Test] + public async Task CrossJoinAndWhereClauseAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var result = await (session.CreateQuery( + "SELECT s " + + "FROM EntityComplex s cross join EntityComplex q " + + "where s.SameTypeChild.Id = q.SameTypeChild.Id" + ).ListAsync()); + + Assert.That(result, Has.Count.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + if (Dialect.SupportsCrossJoin) + { + Assert.That(sqlLog.GetWholeLog(), Does.Contain("cross join"), "A cross join is expected in the SQL select"); + } + } + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 6d11edace21..86802cdb345 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -8,7 +8,13 @@ //------------------------------------------------------------------------------ +using System; using System.Linq; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; using NUnit.Framework; using NHibernate.Linq; @@ -31,5 +37,32 @@ public async Task MultipleLinqJoinsWithSameProjectionNamesAsync() Assert.That(orders.Count, Is.EqualTo(828)); Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } + + [TestCase(false)] + [TestCase(true)] + public async Task CrossJoinWithPredicateInWhereStatementAsync(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var result = await ((from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(result.Count, Is.EqualTo(720)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 0 : 1)); + } + } } } diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index 44fd7ef3cd8..a332a24f6ce 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -9,9 +9,11 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NSubstitute; using NUnit.Framework; using NHibernate.Linq; @@ -757,7 +759,13 @@ from o in c.Orders where c.Address.City == "London" select o; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -863,7 +871,13 @@ from p in db.Products where p.Supplier.Address.Country == "USA" && p.UnitsInStock == 0 select p; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -879,7 +893,16 @@ from et in e.Territories where e.Address.City == "Seattle" select new {e.FirstName, e.LastName, et.Region.Description}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + // EmployeeTerritories and Territories + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + // Region + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -903,7 +926,13 @@ from e2 in e1.Subordinates e1.Address.City }; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -918,7 +947,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders select new {c.ContactName, OrderCount = orders.Average(x => x.Freight)}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -930,7 +965,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId select new { c.ContactName, o.OrderId }; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -959,15 +1000,76 @@ from c in db.Customers } [Category("JOIN")] - [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] - public async Task DLinqJoin5dAsync() + [TestCase(true, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + [TestCase(false, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5dAsync(bool useCrossJoin) { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var q = from c in db.Customers - join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } + join o in db.Orders on + new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals + new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; - await (ObjectDumper.WriteAsync(q)); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + Assert.That(GetTotalOccurrences(sql, "cross join"), Is.EqualTo(useCrossJoin ? 1 : 0)); + } + } + + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public async Task DLinqJoin5eAsync() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5fAsync() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } } [Category("JOIN")] @@ -983,7 +1085,13 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords join e in db.Employees on c.Address.City equals e.Address.City into emps select new {c.ContactName, ords = ords.Count(), emps = emps.Count()}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -997,32 +1105,55 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords from o in ords select new {c.ContactName, o.OrderId, z}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] - [Test(Description = "This sample shows a group join with a composite key.")] - public async Task DLinqJoin9Async() - { - var expected = - (from o in db.Orders.ToList() - from p in db.Products.ToList() - join d in db.OrderLines.ToList() - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); - - var actual = - await ((from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToListAsync()); - - Assert.AreEqual(expected.Count, actual.Count); + [TestCase(true, Description = "This sample shows a group join with a composite key.")] + [TestCase(false, Description = "This sample shows a group join with a composite key.")] + public async Task DLinqJoin9Async(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var actual = + await ((from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(2155)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + } } [Category("JOIN")] @@ -1036,5 +1167,24 @@ group o by c into x await (ObjectDumper.WriteAsync(q)); } + + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables.")] + public async Task DLinqJoin10aAsync() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } } } diff --git a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs index 04cb354f3f4..49a75c4b107 100644 --- a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -1,10 +1,12 @@ using System; +using System.Collections.Generic; using System.Linq; using NHibernate.Cache; using NHibernate.Cfg; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq; using NUnit.Framework; +using NUnit.Framework.Constraints; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Test.FetchLazyProperties @@ -932,6 +934,42 @@ public void TestFetchAfterEntityIsInitialized(bool readOnly) Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), Is.True); } + [Test] + public void TestHqlCrossJoinFetchFormula() + { + if (!Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + var persons = new List(); + var bestFriends = new List(); + using (var sqlSpy = new SqlLogSpy()) + using (var s = OpenSession()) + { + var list = s.CreateQuery("select p, bf from Person p cross join Person bf fetch bf.Formula where bf.Id = p.BestFriend.Id").List(); + foreach (var arr in list) + { + persons.Add((Person) arr[0]); + bestFriends.Add((Person) arr[1]); + } + } + + AssertPersons(persons, false); + AssertPersons(bestFriends, true); + + void AssertPersons(List results, bool fetched) + { + foreach (var person in results) + { + Assert.That(person, Is.Not.Null); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Image"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Address"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), fetched ? Is.True : (IResolveConstraint) Is.False); + } + } + } + private static Person GeneratePerson(int i, Person bestFriend) { return new Person diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index 0b8d0c11a6b..0a578bd982a 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -264,7 +264,7 @@ public void EntityJoinWithFetches() } [Test, Ignore("Failing for unrelated reasons")] - public void CrossJoinAndWithClause() + public void ImplicitJoinAndWithClause() { //This is about complex theta style join fix that was implemented in hibernate along with entity join functionality //https://hibernate.atlassian.net/browse/HHH-7321 @@ -279,6 +279,27 @@ public void CrossJoinAndWithClause() } } + [Test] + public void CrossJoinAndWhereClause() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var result = session.CreateQuery( + "SELECT s " + + "FROM EntityComplex s cross join EntityComplex q " + + "where s.SameTypeChild.Id = q.SameTypeChild.Id" + ).List(); + + Assert.That(result, Has.Count.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + if (Dialect.SupportsCrossJoin) + { + Assert.That(sqlLog.GetWholeLog(), Does.Contain("cross join"), "A cross join is expected in the SQL select"); + } + } + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 1d78b46b181..a013014cae8 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -1,4 +1,10 @@ -using System.Linq; +using System; +using System.Linq; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; using NUnit.Framework; namespace NHibernate.Test.Linq.ByMethod @@ -19,5 +25,32 @@ public void MultipleLinqJoinsWithSameProjectionNames() Assert.That(orders.Count, Is.EqualTo(828)); Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } + + [TestCase(false)] + [TestCase(true)] + public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var result = (from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(result.Count, Is.EqualTo(720)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 0 : 1)); + } + } } } diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index efe18a3caad..193c313130c 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1,7 +1,9 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NSubstitute; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -1301,7 +1303,13 @@ from o in c.Orders where c.Address.City == "London" select o; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1407,7 +1415,13 @@ from p in db.Products where p.Supplier.Address.Country == "USA" && p.UnitsInStock == 0 select p; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1423,7 +1437,16 @@ from et in e.Territories where e.Address.City == "Seattle" select new {e.FirstName, e.LastName, et.Region.Description}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + // EmployeeTerritories and Territories + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + // Region + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1447,7 +1470,13 @@ from e2 in e1.Subordinates e1.Address.City }; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1462,7 +1491,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders select new {c.ContactName, OrderCount = orders.Average(x => x.Freight)}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -1474,7 +1509,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId select new { c.ContactName, o.OrderId }; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1503,15 +1544,76 @@ from c in db.Customers } [Category("JOIN")] - [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] - public void DLinqJoin5d() + [TestCase(true, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + [TestCase(false, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5d(bool useCrossJoin) { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var q = from c in db.Customers - join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } + join o in db.Orders on + new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals + new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; - ObjectDumper.Write(q); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + Assert.That(GetTotalOccurrences(sql, "cross join"), Is.EqualTo(useCrossJoin ? 1 : 0)); + } + } + + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public void DLinqJoin5e() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5f() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } } [Category("JOIN")] @@ -1527,7 +1629,13 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords join e in db.Employees on c.Address.City equals e.Address.City into emps select new {c.ContactName, ords = ords.Count(), emps = emps.Count()}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -1544,7 +1652,13 @@ join o in db.Orders on e equals o.Employee into ords from o in ords.DefaultIfEmpty() select new {e.FirstName, e.LastName, Order = o}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1558,32 +1672,55 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords from o in ords select new {c.ContactName, o.OrderId, z}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] - [Test(Description = "This sample shows a group join with a composite key.")] - public void DLinqJoin9() - { - var expected = - (from o in db.Orders.ToList() - from p in db.Products.ToList() - join d in db.OrderLines.ToList() - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); - - var actual = - (from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + [TestCase(true, Description = "This sample shows a group join with a composite key.")] + [TestCase(false, Description = "This sample shows a group join with a composite key.")] + public void DLinqJoin9(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } - Assert.AreEqual(expected.Count, actual.Count); + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var actual = + (from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(2155)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + } } [Category("JOIN")] @@ -1598,6 +1735,25 @@ group o by c into x ObjectDumper.Write(q); } + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables.")] + public void DLinqJoin10a() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } + [Category("WHERE")] [Test(Description = "This sample uses WHERE to filter for orders with shipping date equals to null.")] public void DLinq2B() diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 4c9ea610c6d..0d59b989cf4 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -16,6 +16,9 @@ using System.Text; using NHibernate.Dialect; using NHibernate.Driver; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; namespace NHibernate.Test { @@ -477,6 +480,69 @@ protected void AssumeFunctionSupported(string functionName) $"{Dialect} doesn't support {functionName} standard function."); } + protected void ClearQueryPlanCache() + { + var planCacheField = typeof(QueryPlanCache) + .GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new InvalidOperationException("planCache field does not exist in QueryPlanCache."); + + var planCache = (SoftLimitMRUCache) planCacheField.GetValue(Sfi.QueryPlanCache); + planCache.Clear(); + } + + protected Substitute SubstituteDialect() + { + var origDialect = Sfi.Settings.Dialect; + var dialectProperty = (PropertyInfo) ReflectHelper.GetProperty(o => o.Dialect); + var forPartsOfMethod = ReflectHelper.GetMethodDefinition(() => Substitute.ForPartsOf()); + var substitute = (Dialect.Dialect) forPartsOfMethod.MakeGenericMethod(origDialect.GetType()) + .Invoke(null, new object[] { new object[0] }); + + dialectProperty.SetValue(Sfi.Settings, substitute); + + return new Substitute(substitute, Dispose); + + void Dispose() + { + dialectProperty.SetValue(Sfi.Settings, origDialect); + } + } + + protected static int GetTotalOccurrences(string content, string substring) + { + if (string.IsNullOrEmpty(substring)) + { + throw new ArgumentNullException(nameof(substring)); + } + + int occurrences = 0, index = 0; + while ((index = content.IndexOf(substring, index)) >= 0) + { + occurrences++; + index += substring.Length; + } + + return occurrences; + } + + protected struct Substitute : IDisposable + { + private readonly System.Action _disposeAction; + + public Substitute(TType value, System.Action disposeAction) + { + Value = value; + _disposeAction = disposeAction; + } + + public TType Value { get; } + + public void Dispose() + { + _disposeAction(); + } + } + #endregion } } diff --git a/src/NHibernate/AdoNet/Util/BasicFormatter.cs b/src/NHibernate/AdoNet/Util/BasicFormatter.cs index 2e601990e06..c3f35fbddf1 100644 --- a/src/NHibernate/AdoNet/Util/BasicFormatter.cs +++ b/src/NHibernate/AdoNet/Util/BasicFormatter.cs @@ -20,6 +20,7 @@ static BasicFormatter() { beginClauses.Add("left"); beginClauses.Add("right"); + beginClauses.Add("cross"); beginClauses.Add("inner"); beginClauses.Add("outer"); beginClauses.Add("group"); diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs index d56ab6dc06f..3eef1635595 100644 --- a/src/NHibernate/Dialect/DB2Dialect.cs +++ b/src/NHibernate/Dialect/DB2Dialect.cs @@ -300,6 +300,9 @@ public override string ForUpdateString public override bool SupportsResultSetPositionQueryMethodsOnForwardOnlyCursor => false; + /// + public override bool SupportsCrossJoin => false; // DB2 v9.1 doesn't support 'cross join' syntax + public override bool SupportsLobValueChangePropogation => false; public override bool SupportsExistsInSelect => false; diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index e3d0b6cacd9..423c5082361 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -1337,6 +1337,11 @@ public virtual JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } + /// + /// Does this dialect support CROSS JOIN? + /// + public virtual bool SupportsCrossJoin => true; + /// /// Create a strategy responsible /// for handling this dialect's variations in how CASE statements are diff --git a/src/NHibernate/Dialect/InformixDialect0940.cs b/src/NHibernate/Dialect/InformixDialect0940.cs index 13563df43e8..bdfcae4f27f 100644 --- a/src/NHibernate/Dialect/InformixDialect0940.cs +++ b/src/NHibernate/Dialect/InformixDialect0940.cs @@ -126,7 +126,10 @@ public override JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } - /// + /// + public override bool SupportsCrossJoin => false; + + /// /// Does this Dialect have some kind of LIMIT syntax? /// /// False, unless overridden. diff --git a/src/NHibernate/Dialect/Oracle10gDialect.cs b/src/NHibernate/Dialect/Oracle10gDialect.cs index 7219390c0e9..caab3e1f492 100644 --- a/src/NHibernate/Dialect/Oracle10gDialect.cs +++ b/src/NHibernate/Dialect/Oracle10gDialect.cs @@ -15,5 +15,8 @@ public override JoinFragment CreateOuterJoinFragment() { return new ANSIJoinFragment(); } + + /// + public override bool SupportsCrossJoin => true; } } \ No newline at end of file diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index f3a943a3b8e..b2103b87965 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -328,6 +328,9 @@ public override JoinFragment CreateOuterJoinFragment() return new OracleJoinFragment(); } + /// + public override bool SupportsCrossJoin => false; + /// /// Map case support to the Oracle DECODE function. Oracle did not /// add support for CASE until 9i. diff --git a/src/NHibernate/Dialect/SybaseASA9Dialect.cs b/src/NHibernate/Dialect/SybaseASA9Dialect.cs index 12ad1e3c088..f839871b6a1 100644 --- a/src/NHibernate/Dialect/SybaseASA9Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASA9Dialect.cs @@ -97,6 +97,9 @@ public override bool OffsetStartsAtOne get { return true; } } + /// + public override bool SupportsCrossJoin => false; + public override SqlString GetLimitString(SqlString queryString, SqlString offset, SqlString limit) { int intSelectInsertPoint = GetAfterSelectInsertPoint(queryString); diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index 1b33b672cdf..0a84a822424 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -247,7 +247,10 @@ public override bool SupportsExpectedLobUsagePattern { get { return false; } } - + + /// + public override bool SupportsCrossJoin => false; + public override char OpenQuote { get { return '['; } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Hql.g b/src/NHibernate/Hql/Ast/ANTLR/Hql.g index 9f67aaa2162..efe660d8918 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Hql.g +++ b/src/NHibernate/Hql/Ast/ANTLR/Hql.g @@ -19,6 +19,7 @@ tokens BETWEEN='between'; CLASS='class'; COUNT='count'; + CROSS='cross'; DELETE='delete'; DESCENDING='desc'; DOT; @@ -255,6 +256,7 @@ fromClause fromJoin : ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? path (asAlias)? (propertyFetch)? (withClause)? | ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? ELEMENTS! OPEN! path CLOSE! (asAlias)? (propertyFetch)? (withClause)? + | CROSS JOIN^ { WeakKeywords(); } path (asAlias)? (propertyFetch)? ; withClause diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g index 22e5450dce1..fba1010337c 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g @@ -306,6 +306,9 @@ joinType returns [int j] | INNER { $j = INNER; } + | CROSS { + $j = CROSS; + } ; // Matches a path and returns the normalized string for the path (usually diff --git a/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs b/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs index c4f850f7aa3..8fbcfba7fd7 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs @@ -48,6 +48,8 @@ public static JoinType ToHibernateJoinType(int astJoinType) return JoinType.RightOuterJoin; case HqlSqlWalker.FULL: return JoinType.FullJoin; + case HqlSqlWalker.CROSS: + return JoinType.CrossJoin; default: throw new AssertionFailure("undefined join type " + astJoinType); } diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index 49cad7607e1..bb208295afc 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -478,11 +478,21 @@ public HqlIn In(HqlExpression itemExpression, HqlTreeNode source) return new HqlIn(_factory, itemExpression, source); } + public HqlInnerJoin InnerJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlInnerJoin(_factory, expression, @alias); + } + public HqlLeftJoin LeftJoin(HqlExpression expression, HqlAlias @alias) { return new HqlLeftJoin(_factory, expression, @alias); } + public HqlCrossJoin CrossJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlCrossJoin(_factory, expression, @alias); + } + public HqlFetchJoin FetchJoin(HqlExpression expression, HqlAlias @alias) { return new HqlFetchJoin(_factory, expression, @alias); diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 5964b99db90..160739d7f16 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -844,6 +844,14 @@ public HqlJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : } } + public class HqlInnerJoin : HqlTreeNode + { + public HqlInnerJoin(IASTFactory factory, HqlExpression expression, HqlAlias alias) + : base(HqlSqlWalker.JOIN, "join", factory, new HqlInner(factory), expression, alias) + { + } + } + public class HqlLeftJoin : HqlTreeNode { public HqlLeftJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlLeft(factory), expression, @alias) @@ -851,6 +859,13 @@ public class HqlLeftJoin : HqlTreeNode } } + public class HqlCrossJoin : HqlTreeNode + { + public HqlCrossJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlCross(factory), expression, @alias) + { + } + } + public class HqlFetchJoin : HqlTreeNode { public HqlFetchJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) @@ -898,6 +913,14 @@ public HqlBitwiseAnd(IASTFactory factory, HqlExpression lhs, HqlExpression rhs) } } + public class HqlInner : HqlTreeNode + { + public HqlInner(IASTFactory factory) + : base(HqlSqlWalker.INNER, "inner", factory) + { + } + } + public class HqlLeft : HqlTreeNode { public HqlLeft(IASTFactory factory) @@ -906,6 +929,14 @@ public HqlLeft(IASTFactory factory) } } + public class HqlCross : HqlTreeNode + { + public HqlCross(IASTFactory factory) + : base(HqlSqlWalker.CROSS, "cross", factory) + { + } + } + public class HqlAny : HqlBooleanExpression { public HqlAny(IASTFactory factory) : base(HqlSqlWalker.ANY, "any", factory) diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs index d68926f88b3..3a1c17147aa 100644 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ b/src/NHibernate/Linq/Clauses/NhJoinClause.cs @@ -54,6 +54,8 @@ public NhJoinClause(string itemName, System.Type itemType, Expression fromExpres public bool IsInner { get; private set; } + internal JoinClause ParentJoinClause { get; set; } + public void TransformExpressions(Func transformation) { if (transformation == null) throw new ArgumentNullException(nameof(transformation)); diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 576bb65a9ea..d022e1ffc88 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Specialized; using System.Linq; using NHibernate.Engine; using NHibernate.Linq.Clauses; @@ -19,11 +20,12 @@ public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider private readonly ISessionFactoryImplementor _sessionFactory; private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; + private JoinClause _currentJoin; private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) { _sessionFactory = sessionFactory; - var joiner = new Joiner(queryModel); + var joiner = new Joiner(queryModel, AddJoin); _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); _whereJoinDetector = new WhereJoinDetector(this, joiner); } @@ -59,6 +61,22 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel _whereJoinDetector.Transform(havingClause); } + public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) + { + joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence); + + // When associations are located in the outer key (e.g. from a in A join b in B b on a.C.D.Id equals b.Id), + // do nothing and leave them to HQL for adding the missing joins. + + // When associations are located in the inner key (e.g. from a in A join b in B b on a.Id equals b.C.D.Id), + // we have to move the condition to the where statement, otherwise the query will be invalid (HQL does not + // support them). + // Link newly created joins with the current join clause in order to later detect which join type to use. + _currentJoin = joinClause; + joinClause.InnerKeySelector = _whereJoinDetector.Transform(joinClause.InnerKeySelector); + _currentJoin = null; + } + public bool IsEntity(System.Type type) { return _sessionFactory.GetImplementors(type.FullName).Any(); @@ -69,5 +87,17 @@ public bool IsIdentifier(System.Type type, string propertyName) var metadata = _sessionFactory.GetClassMetadata(type); return metadata != null && propertyName.Equals(metadata.IdentifierPropertyName); } + + private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) + { + joinClause.ParentJoinClause = _currentJoin; + if (_currentJoin != null) + { + // Match the parent join type + joinClause.MakeInner(); + } + + queryModel.BodyClauses.Add(joinClause); + } } } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index b84a72c7ac3..dc7bc395c13 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -21,11 +21,20 @@ public class Joiner : IJoiner private readonly NameGenerator _nameGenerator; private readonly QueryModel _queryModel; + internal Joiner(QueryModel queryModel, System.Action addJoinMethod) + : this(queryModel) + { + AddJoinMethod = addJoinMethod; + } + internal Joiner(QueryModel queryModel) { _nameGenerator = new NameGenerator(queryModel); _queryModel = queryModel; + AddJoinMethod = AddJoin; } + + internal System.Action AddJoinMethod { get; } public IEnumerable Joins { @@ -39,7 +48,7 @@ public Expression AddJoin(Expression expression, string key) if (!_joins.TryGetValue(key, out join)) { join = new NhJoinClause(_nameGenerator.GetNewName(), expression.Type, expression); - _queryModel.BodyClauses.Add(join); + AddJoinMethod(_queryModel, join); _joins.Add(key, join); } @@ -72,6 +81,11 @@ public bool CanAddJoin(Expression expression) return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase); } + private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) + { + queryModel.BodyClauses.Add(joinClause); + } + private class QuerySourceExtractor : RelinqExpressionVisitor { private IQuerySource _querySource; @@ -90,4 +104,4 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index f043d30c51f..a487a1281ce 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Hql.Ast; @@ -17,6 +18,8 @@ using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; using Remotion.Linq.EagerFetching; +using OrderByClause = Remotion.Linq.Clauses.OrderByClause; +using SelectClause = Remotion.Linq.Clauses.SelectClause; namespace NHibernate.Linq.Visitors { @@ -315,22 +318,16 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - + var fromExpressionTree = HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters); + var alias = _hqlTree.TreeBuilder.Alias(querySourceName); if (fromClause.FromExpression is MemberExpression) { // It's a join - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Join( - HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(_hqlTree.TreeBuilder.Join(fromExpressionTree.AsExpression(), alias)); } else { - // TODO - exact same code as in MainFromClause; refactor this out - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(CreateCrossJoin(fromExpressionTree, alias)); } base.VisitAdditionalFromClause(fromClause, queryModel, index); @@ -517,15 +514,24 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) { var equalityVisitor = new EqualityHqlGenerator(VisitorParameters); - var whereClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(joinClause); - - _hqlTree.AddWhereClause(whereClause); + var withClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); + var alias = _hqlTree.TreeBuilder.Alias(VisitorParameters.QuerySourceNamer.GetName(joinClause)); + var joinExpression = HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters); + HqlTreeNode join; + // When associations are located inside the inner key selector we have to use a cross join instead of an inner + // join and add the condition in the where statement. + if (queryModel.BodyClauses.OfType().Any(o => o.ParentJoinClause == joinClause)) + { + _hqlTree.AddWhereClause(withClause); + join = CreateCrossJoin(joinExpression, alias); + } + else + { + join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + join.AddChild(_hqlTree.TreeBuilder.With(withClause)); + } - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(join); } public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, QueryModel queryModel, int index) @@ -552,5 +558,22 @@ public override void VisitNhWithClause(NhWithClause withClause, QueryModel query var expression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); _hqlTree.AddWhereClause(expression); } + + private HqlTreeNode CreateCrossJoin(HqlTreeNode joinExpression, HqlAlias alias) + { + if (VisitorParameters.SessionFactory.Dialect.SupportsCrossJoin) + { + return _hqlTree.TreeBuilder.CrossJoin(joinExpression.AsExpression(), alias); + } + + // Simulate cross join as a inner join on 1=1 + var join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + var onExpression = _hqlTree.TreeBuilder.Equality( + _hqlTree.TreeBuilder.True(), + _hqlTree.TreeBuilder.True()); + join.AddChild(_hqlTree.TreeBuilder.With(onExpression)); + + return join; + } } } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 66508f01eef..68f88c6cc55 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -77,10 +77,21 @@ internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) _joiner = joiner; } + public Expression Transform(Expression expression) + { + var result = Visit(expression); + PostTransform(); + return result; + } + public void Transform(IClause whereClause) { whereClause.TransformExpressions(Visit); + PostTransform(); + } + private void PostTransform() + { var values = _values.Pop(); foreach (var memberExpression in values.MemberExpressions) diff --git a/src/NHibernate/SqlCommand/ANSIJoinFragment.cs b/src/NHibernate/SqlCommand/ANSIJoinFragment.cs index aa9421fa08e..063c361eb56 100644 --- a/src/NHibernate/SqlCommand/ANSIJoinFragment.cs +++ b/src/NHibernate/SqlCommand/ANSIJoinFragment.cs @@ -33,12 +33,21 @@ public override void AddJoin(string tableName, string alias, string[] fkColumns, case JoinType.FullJoin: joinString = " full outer join "; break; + case JoinType.CrossJoin: + joinString = " cross join "; + break; default: throw new AssertionFailure("undefined join type"); } - _fromFragment.Add(joinString + tableName + ' ' + alias + " on "); + _fromFragment.Add(joinString).Add(tableName).Add(" ").Add(alias).Add(" "); + if (joinType == JoinType.CrossJoin) + { + // Cross join does not have an 'on' statement + return; + } + _fromFragment.Add("on "); if (fkColumns.Length == 0) { AddBareCondition(_fromFragment, on); diff --git a/src/NHibernate/SqlCommand/JoinFragment.cs b/src/NHibernate/SqlCommand/JoinFragment.cs index 18b249bd57f..38c5d8ce280 100644 --- a/src/NHibernate/SqlCommand/JoinFragment.cs +++ b/src/NHibernate/SqlCommand/JoinFragment.cs @@ -10,7 +10,8 @@ public enum JoinType InnerJoin = 0, FullJoin = 4, LeftOuterJoin = 1, - RightOuterJoin = 2 + RightOuterJoin = 2, + CrossJoin = 8 } ///