From 88abcf0629fb454baaaef149a804c5da51afe33a Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Thu, 19 Mar 2020 10:05:54 +0200 Subject: [PATCH 01/43] improve adding of with clauses when entity overrides property from base class --- src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 46 +++++++++++- src/NHibernate.Test/Hql/Node.cs | 75 +++++++++++++++++++ .../Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 5 ++ .../Criteria/CriteriaQueryTranslator.cs | 2 +- .../Collection/AbstractCollectionPersister.cs | 6 ++ .../Collection/CollectionPropertyMapping.cs | 8 +- .../Collection/ElementPropertyMapping.cs | 8 +- .../Entity/AbstractEntityPersister.cs | 5 ++ .../Entity/AbstractPropertyMapping.cs | 6 ++ .../Entity/BasicEntityPropertyMapping.cs | 14 ++++ .../Persister/Entity/IPropertyMapping.cs | 4 +- 11 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 src/NHibernate.Test/Hql/Node.cs diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index 0b8d0c11a6b..1ef0765eafa 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -1,5 +1,8 @@ -using System.Text.RegularExpressions; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; using NHibernate.Cfg.MappingSchema; +using NHibernate.Criterion; using NHibernate.Mapping.ByCode; using NHibernate.Test.Hql.EntityJoinHqlTestEntities; using NUnit.Framework; @@ -279,6 +282,43 @@ public void CrossJoinAndWithClause() } } + [Test] + public void Join_Inheritance() + { + // arrange + IEnumerable results; + var person = new PersonBase { Login = "dave", FamilyName = "grohl" }; + var visit_1 = new UserEntityVisit { PersonBase = person }; + var visit_2 = new UserEntityVisit { PersonBase = person }; + + using (ISession arrangeSession = OpenSession()) + using (ITransaction tx = arrangeSession.BeginTransaction()) + { + arrangeSession.Save(person); + arrangeSession.Save(visit_1); + arrangeSession.Save(visit_2); + arrangeSession.Flush(); + + tx.Commit(); + } + + // act + using (var session = OpenSession()) + { + results = session.CreateCriteria() + .CreateCriteria( + $"{nameof(UserEntityVisit.PersonBase)}", + "f", + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Eq("Deleted", false)) + .List() + .Select(x => x.Id); + } + + // assert + Assert.That(results, Is.EquivalentTo(new[] { visit_1.Id, visit_2.Id, })); + } + #region Test Setup protected override HbmMapping GetMappings() @@ -351,6 +391,10 @@ protected override HbmMapping GetMappings() rc.Property(e => e.Name); }); + + Node.AddMapping(mapper); + UserEntityVisit.AddMapping(mapper); + return mapper.CompileMappingForAllExplicitlyAddedEntities(); } diff --git a/src/NHibernate.Test/Hql/Node.cs b/src/NHibernate.Test/Hql/Node.cs new file mode 100644 index 00000000000..e3cb2937002 --- /dev/null +++ b/src/NHibernate.Test/Hql/Node.cs @@ -0,0 +1,75 @@ +using System; +using NHibernate.Mapping.ByCode; + +namespace NHibernate.Test.Hql +{ + public abstract class Node + { + private int _id; + public virtual int Id + { + get { return _id; } + set { _id = value; } + } + + public virtual bool Deleted { get; set; } + public virtual string FamilyName { get; set; } + + public static void AddMapping(ModelMapper mapper) + { + mapper.Class(ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.Property(x => x.FamilyName); + ca.Table("Node"); + ca.Abstract(true); + }); + + mapper.JoinedSubclass(ca => + { + ca.Key(x => x.Column("FK_Node_ID")); + ca.Extends(typeof(Node)); + ca.Property(x => x.Deleted); + ca.Property(x => x.Login); + }); + } + } + + [Serializable] + public class PersonBase : Node + { + public virtual string Login { get; set; } + public override bool Deleted { get; set; } + } + + [Serializable] + public class UserEntityVisit + { + private int _id; + public virtual int Id + { + get { return _id; } + set { _id = value; } + } + + public virtual bool Deleted { get; set; } + + private PersonBase _PersonBase; + public virtual PersonBase PersonBase + { + get { return _PersonBase; } + set { _PersonBase = value; } + } + + public static void AddMapping(ModelMapper mapper) + { + mapper.Class(ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.ManyToOne(x => x.PersonBase); + }); + } + } +} diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index bfffb9be928..f434b3b7f51 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,6 +150,11 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + return fromElementType.GetBasePropertyMapping().ToColumns(pathCriteria, GetPropertyPath(propertyName), getSQLAlias); + } + public string[] ToColumns(string alias, string propertyName) { return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 682dc73e084..8ebad5641c4 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -766,7 +766,7 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); + columns = propertyMapping.ToColumns(pathCriteria, propertyName, GetSQLAlias); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 20701df9435..593482633e8 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1386,6 +1386,12 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public string[] ToColumns(string alias, string propertyName) { if ("index".Equals(propertyName)) diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index e9e2f89dc51..de56cb9cd81 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,6 +57,12 @@ public bool TryToType(string propertyName, out IType type) } } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public string[] ToColumns(string alias, string propertyName) { string[] cols; @@ -117,4 +123,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index 20e9899ddb6..ad412a19774 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,6 +47,12 @@ public bool TryToType(string propertyName, out IType outType) } } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public string[] ToColumns(string alias, string propertyName) { if (propertyName == null || "id".Equals(propertyName)) @@ -71,4 +77,4 @@ public IType Type #endregion } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 924da726cc1..879a0d670f3 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2050,6 +2050,11 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } + public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + return propertyMapping.ToColumns(pathCriteria, propertyName, getSQLAlias); + } + public virtual string[] ToColumns(string alias, string propertyName) { return propertyMapping.ToColumns(alias, propertyName); diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index c027568bf18..46e8ca70e34 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,6 +44,12 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } + public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public virtual string[] ToColumns(string alias, string propertyName) { //TODO: *two* hashmap lookups here is one too many... diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 02f625bd550..6c7b31a6940 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -1,4 +1,7 @@ +using System; +using NHibernate.Criterion; using NHibernate.Type; +using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Persister.Entity { @@ -26,6 +29,17 @@ public override IType Type get { return persister.Type; } } + public override string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; + if (withClause != null && withClause.PropertyName == propertyName) + { + return base.ToColumns(persister.GenerateTableAlias(getSQLAlias(pathCriteria), 0), propertyName); + } + + return base.ToColumns(pathCriteria, propertyName, getSQLAlias); + } + public override string[] ToColumns(string alias, string propertyName) { return diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index dbe08dd9139..b348d36eae9 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -29,6 +29,8 @@ public interface IPropertyMapping /// true if a type was found, false if not bool TryToType(string propertyName, out IType type); + string[] ToColumns(ICriteria pathCriteria, string propertyName, System.Func getSQLAlias); + /// /// Given a query alias and a property path, return the qualified column name /// @@ -40,4 +42,4 @@ public interface IPropertyMapping /// Given a property path, return the corresponding column name(s). string[] ToColumns(string propertyName); } -} \ No newline at end of file +} From 881ba039a84ea2a893ba086726d0f5325d130e43 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Tue, 31 Mar 2020 12:25:11 +0300 Subject: [PATCH 02/43] remove ToColumns with 3 arguments --- src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 5 ----- .../Loader/Criteria/CriteriaQueryTranslator.cs | 2 +- .../Collection/AbstractCollectionPersister.cs | 6 ------ .../Collection/CollectionPropertyMapping.cs | 8 +------- .../Persister/Collection/ElementPropertyMapping.cs | 8 +------- .../Persister/Entity/AbstractEntityPersister.cs | 5 ----- .../Persister/Entity/AbstractPropertyMapping.cs | 6 ------ .../Persister/Entity/BasicEntityPropertyMapping.cs | 14 -------------- .../Persister/Entity/IPropertyMapping.cs | 4 +--- 9 files changed, 4 insertions(+), 54 deletions(-) diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index f434b3b7f51..bfffb9be928 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,11 +150,6 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - return fromElementType.GetBasePropertyMapping().ToColumns(pathCriteria, GetPropertyPath(propertyName), getSQLAlias); - } - public string[] ToColumns(string alias, string propertyName) { return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 8ebad5641c4..682dc73e084 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -766,7 +766,7 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - columns = propertyMapping.ToColumns(pathCriteria, propertyName, GetSQLAlias); + columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 593482633e8..20701df9435 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1386,12 +1386,6 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public string[] ToColumns(string alias, string propertyName) { if ("index".Equals(propertyName)) diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index de56cb9cd81..e9e2f89dc51 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,12 +57,6 @@ public bool TryToType(string propertyName, out IType type) } } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public string[] ToColumns(string alias, string propertyName) { string[] cols; @@ -123,4 +117,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index ad412a19774..20e9899ddb6 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,12 +47,6 @@ public bool TryToType(string propertyName, out IType outType) } } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public string[] ToColumns(string alias, string propertyName) { if (propertyName == null || "id".Equals(propertyName)) @@ -77,4 +71,4 @@ public IType Type #endregion } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 879a0d670f3..924da726cc1 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2050,11 +2050,6 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } - public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - return propertyMapping.ToColumns(pathCriteria, propertyName, getSQLAlias); - } - public virtual string[] ToColumns(string alias, string propertyName) { return propertyMapping.ToColumns(alias, propertyName); diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index 46e8ca70e34..c027568bf18 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,12 +44,6 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } - public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public virtual string[] ToColumns(string alias, string propertyName) { //TODO: *two* hashmap lookups here is one too many... diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 6c7b31a6940..02f625bd550 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -1,7 +1,4 @@ -using System; -using NHibernate.Criterion; using NHibernate.Type; -using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Persister.Entity { @@ -29,17 +26,6 @@ public override IType Type get { return persister.Type; } } - public override string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; - if (withClause != null && withClause.PropertyName == propertyName) - { - return base.ToColumns(persister.GenerateTableAlias(getSQLAlias(pathCriteria), 0), propertyName); - } - - return base.ToColumns(pathCriteria, propertyName, getSQLAlias); - } - public override string[] ToColumns(string alias, string propertyName) { return diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index b348d36eae9..dbe08dd9139 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -29,8 +29,6 @@ public interface IPropertyMapping /// true if a type was found, false if not bool TryToType(string propertyName, out IType type); - string[] ToColumns(ICriteria pathCriteria, string propertyName, System.Func getSQLAlias); - /// /// Given a query alias and a property path, return the qualified column name /// @@ -42,4 +40,4 @@ public interface IPropertyMapping /// Given a property path, return the corresponding column name(s). string[] ToColumns(string propertyName); } -} +} \ No newline at end of file From 487bcb981be48313fd8cc9bacd932b57ae482087 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Tue, 31 Mar 2020 12:39:24 +0300 Subject: [PATCH 03/43] consider overriden properties when getting subclass property table number --- src/NHibernate/Persister/Entity/AbstractEntityPersister.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 924da726cc1..0bf2e57c673 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2110,7 +2110,7 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath) return getSubclassColumnTableNumberClosure()[idx]; } }*/ - int index = Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! + int index = Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); } From 4844868dd5b17a36f157dea317263c9e53f7a648 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Thu, 9 Apr 2020 18:45:35 +0300 Subject: [PATCH 04/43] special handling of with clauses --- src/NHibernate.Test/App.config | 4 +-- .../Ast/ANTLR/Tree/AssignmentSpecification.cs | 2 +- .../Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 8 ++--- .../Hql/Ast/ANTLR/Tree/IntoClause.cs | 4 +-- .../Criteria/CriteriaQueryTranslator.cs | 11 ++++++- .../Collection/AbstractCollectionPersister.cs | 8 ++--- .../Collection/CollectionPropertyMapping.cs | 6 ++-- .../Collection/ElementPropertyMapping.cs | 6 ++-- .../Entity/AbstractEntityPersister.cs | 31 ++++++++++--------- .../Entity/AbstractPropertyMapping.cs | 4 +-- .../Entity/BasicEntityPropertyMapping.cs | 7 ++--- .../Persister/Entity/IPropertyMapping.cs | 7 +++-- src/NHibernate/Persister/Entity/IQueryable.cs | 3 +- .../Entity/JoinedSubclassEntityPersister.cs | 6 ++-- .../Entity/SingleTableEntityPersister.cs | 10 +++--- .../Entity/UnionSubclassEntityPersister.cs | 4 +-- 16 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/NHibernate.Test/App.config b/src/NHibernate.Test/App.config index d3965012af5..8d0aa714996 100644 --- a/src/NHibernate.Test/App.config +++ b/src/NHibernate.Test/App.config @@ -7,7 +7,7 @@ - + @@ -31,7 +31,7 @@ NHibernate.Dialect.MsSql2008Dialect NHibernate.Driver.Sql2008ClientDriver - Server=localhost\sqlexpress;Database=nhibernate;Integrated Security=SSPI + Server=localhost;Database=nhibernate;Integrated Security=SSPI NHibernate.Test.DebugConnectionProvider, NHibernate.Test ReadCommitted diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs index 351d29318fa..4f450eead3f 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs @@ -61,7 +61,7 @@ public AssignmentSpecification(IASTNode eq, IQueryable persister) } else { - temp.Add(persister.GetSubclassTableName(persister.GetSubclassPropertyTableNumber(propertyPath))); + temp.Add(persister.GetSubclassTableName(persister.GetSubclassPropertyTableNumber(propertyPath, false))); } _tableNames = new HashSet(temp); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index bfffb9be928..448bf8ceda2 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,14 +150,14 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { - return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); + return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName), useLastIndex); } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { - return fromElementType.GetBasePropertyMapping().ToColumns(GetPropertyPath(propertyName)); + return fromElementType.GetBasePropertyMapping().ToColumns(GetPropertyPath(propertyName), useLastIndex); } #endregion diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs index a59e6d5a8ec..05e4b9a4248 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs @@ -214,7 +214,7 @@ private bool IsSuperclassProperty(string propertyName) // // we may want to disallow it for discrim-subclass just for // consistency-sake (currently does not work anyway)... - return _persister.GetSubclassPropertyTableNumber(propertyName) != 0; + return _persister.GetSubclassPropertyTableNumber(propertyName, false) != 0; } /// @@ -263,4 +263,4 @@ private static bool AreSqlTypesCompatible(SqlType target, SqlType source) return target.Equals(source); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 682dc73e084..6dafb3fda1d 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -13,6 +13,7 @@ using NHibernate.Type; using NHibernate.Util; using IQueryable = NHibernate.Persister.Entity.IQueryable; +using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Loader.Criteria { @@ -766,7 +767,15 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); + // here we can check if the condition belongs to a with clause + bool useLastIndex = false; + var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; + if (withClause != null && withClause.PropertyName == propertyName) + { + useLastIndex = true; + } + + columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName, useLastIndex); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 20701df9435..707bf423d38 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1386,7 +1386,7 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { if ("index".Equals(propertyName)) { @@ -1397,10 +1397,10 @@ public string[] ToColumns(string alias, string propertyName) return StringHelper.Qualify(alias, indexColumnNames); } - return elementPropertyMapping.ToColumns(alias, propertyName); + return elementPropertyMapping.ToColumns(alias, propertyName, useLastIndex); } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { if ("index".Equals(propertyName)) { @@ -1412,7 +1412,7 @@ public string[] ToColumns(string propertyName) return indexColumnNames; } - return elementPropertyMapping.ToColumns(propertyName); + return elementPropertyMapping.ToColumns(propertyName, useLastIndex); } protected abstract SqlCommandInfo GenerateDeleteString(); diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index e9e2f89dc51..569c53eb63d 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,7 +57,7 @@ public bool TryToType(string propertyName, out IType type) } } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { string[] cols; switch (propertyName) @@ -107,7 +107,7 @@ public string[] ToColumns(string alias, string propertyName) } } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { throw new System.NotSupportedException("References to collections must be define a SQL alias"); } @@ -117,4 +117,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index 20e9899ddb6..6ab04642a5f 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,7 +47,7 @@ public bool TryToType(string propertyName, out IType outType) } } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex) { if (propertyName == null || "id".Equals(propertyName)) { @@ -59,7 +59,7 @@ public string[] ToColumns(string alias, string propertyName) } } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex) { throw new System.NotSupportedException("References to collections must be define a SQL alias"); } @@ -71,4 +71,4 @@ public IType Type #endregion } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 0bf2e57c673..32639db4d0e 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -1118,9 +1118,9 @@ protected virtual bool IsIdOfTable(int property, int table) return false; } - protected abstract int GetSubclassPropertyTableNumber(int i); + protected abstract int GetSubclassPropertyTableNumber(int i, bool useLastIndex); - internal int GetSubclassPropertyTableNumber(string propertyName, string entityName) + internal int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) { var type = propertyMapping.ToType(propertyName); if (type.IsAssociationType && ((IAssociationType) type).UseLHSPrimaryKey) @@ -1271,7 +1271,7 @@ protected internal virtual SqlString GenerateLazySelectString() // use the subclass closure int propertyNumber = GetSubclassPropertyIndex(lazyPropertyNames[i]); - int tableNumber = GetSubclassPropertyTableNumber(propertyNumber); + int tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); tableNumbers.Add(tableNumber); int[] colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -1326,7 +1326,7 @@ protected virtual IDictionary GenerateLazySelectStringsByFetc // use the subclass closure var propertyNumber = GetSubclassPropertyIndex(lazyPropertyDescriptor.Name); - var tableNumber = GetSubclassPropertyTableNumber(propertyNumber); + var tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); tableNumbers.Add(tableNumber); var colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -2050,12 +2050,12 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } - public virtual string[] ToColumns(string alias, string propertyName) + public virtual string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { - return propertyMapping.ToColumns(alias, propertyName); + return propertyMapping.ToColumns(alias, propertyName, useLastIndex); } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { return propertyMapping.GetColumnNames(propertyName); } @@ -2083,7 +2083,7 @@ public string[] GetPropertyColumnNames(string propertyName) /// SingleTableEntityPersister defines an overloaded form /// which takes the entity name. /// - public virtual int GetSubclassPropertyTableNumber(string propertyPath) + public virtual int GetSubclassPropertyTableNumber(string propertyPath, bool useLastIndex) { string rootPropertyName = StringHelper.Root(propertyPath); IType type = propertyMapping.ToType(rootPropertyName); @@ -2110,13 +2110,16 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath) return getSubclassColumnTableNumberClosure()[idx]; } }*/ - int index = Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! - return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); + int index = useLastIndex + ? Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName) + : Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! + + return index == -1 ? 0 : GetSubclassPropertyTableNumber(index, false); } public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath) { - int tableIndex = GetSubclassPropertyTableNumber(propertyPath); + int tableIndex = GetSubclassPropertyTableNumber(propertyPath, false); if (tableIndex == 0) { return Declarer.Class; @@ -2164,7 +2167,7 @@ private string GetSubclassAliasedColumn(string rootAlias, int tableNumber, strin public string[] ToColumns(string name, int i) { - string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i)); + string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i, false)); string[] cols = GetSubclassPropertyColumnNames(i); string[] templates = SubclassPropertyFormulaTemplateClosure[i]; string[] result = new string[cols.Length]; @@ -2398,7 +2401,7 @@ private EntityLoader GetAppropriateUniqueKeyLoader(string propertyName, IDiction return uniqueKeyLoaders[propertyName]; } - return CreateUniqueKeyLoader(propertyMapping.ToType(propertyName), propertyMapping.ToColumns(propertyName), enabledFilters); + return CreateUniqueKeyLoader(propertyMapping.ToType(propertyName), propertyMapping.ToColumns(propertyName, false), enabledFilters); } public int GetPropertyIndex(string propertyName) @@ -3682,7 +3685,7 @@ private IDictionary GetColumnsToTableAliasMap(string rootAlias) if (cols != null && cols.Length > 0) { - PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i)); + PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i, false)); propDictionary[key] = property; } } diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index c027568bf18..40f9550802e 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,7 +44,7 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } - public virtual string[] ToColumns(string alias, string propertyName) + public virtual string[] ToColumns(string alias, string propertyName, bool useLastIndex) { //TODO: *two* hashmap lookups here is one too many... string[] columns = GetColumns(propertyName); @@ -71,7 +71,7 @@ private string[] GetColumns(string propertyName) return columns; } - public virtual string[] ToColumns(string propertyName) + public virtual string[] ToColumns(string propertyName, bool useLastIndex) { string[] columns = GetColumns(propertyName); diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 02f625bd550..ff0e71aefc0 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -26,11 +26,10 @@ public override IType Type get { return persister.Type; } } - public override string[] ToColumns(string alias, string propertyName) + public override string[] ToColumns(string alias, string propertyName, bool useLastIndex) { - return - base.ToColumns(persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName)), - propertyName); + var tableAlias = persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName, useLastIndex)); + return base.ToColumns(tableAlias, propertyName, useLastIndex); } } } diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index dbe08dd9139..fc1dc5bf495 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -34,10 +34,11 @@ public interface IPropertyMapping /// /// /// + /// /// - string[] ToColumns(string alias, string propertyName); + string[] ToColumns(string alias, string propertyName, bool useLastIndex = false); /// Given a property path, return the corresponding column name(s). - string[] ToColumns(string propertyName); + string[] ToColumns(string propertyName, bool useLastIndex = false); } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Entity/IQueryable.cs b/src/NHibernate/Persister/Entity/IQueryable.cs index 2178b43a024..fdf69b1ddf8 100644 --- a/src/NHibernate/Persister/Entity/IQueryable.cs +++ b/src/NHibernate/Persister/Entity/IQueryable.cs @@ -112,13 +112,14 @@ public interface IQueryable : ILoadable, IPropertyMapping, IJoinable /// to which this property is mapped. /// /// The name of the property. + /// The name of the property. /// The number of the table to which the property is mapped. /// /// Note that this is not relative to the results from {@link #getConstraintOrderedTableNameClosure()}. /// It is relative to the subclass table name closure maintained internal to the persister (yick!). /// It is also relative to the indexing used to resolve {@link #getSubclassTableName}... /// - int GetSubclassPropertyTableNumber(string propertyPath); + int GetSubclassPropertyTableNumber(string propertyPath, bool useLastIndex); /// Determine whether the given property is declared by our /// mapped class, our super class, or one of our subclasses... diff --git a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs index 47b0a7c19a7..d11fb3d37c9 100644 --- a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs @@ -526,7 +526,7 @@ public override string GenerateFilterConditionAlias(string rootAlias) return GenerateTableAlias(rootAlias, tableSpan - 1); } - public override string[] ToColumns(string alias, string propertyName) + public override string[] ToColumns(string alias, string propertyName, bool useLastIndex) { if (EntityClass.Equals(propertyName)) { @@ -542,11 +542,11 @@ public override string[] ToColumns(string alias, string propertyName) } else { - return base.ToColumns(alias, propertyName); + return base.ToColumns(alias, propertyName, useLastIndex); } } - protected override int GetSubclassPropertyTableNumber(int i) + protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) { return subclassPropertyTableNumberClosure[i]; } diff --git a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs index 9aa8a71a3e4..4dbb004fc8e 100644 --- a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs @@ -675,7 +675,7 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i) + protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) { return subclassPropertyTableNumberClosure[i]; } @@ -696,12 +696,12 @@ protected override void AddDiscriminatorToInsert(SqlInsertBuilder insert) protected override bool IsSubclassPropertyDeferred(string propertyName, string entityName) { return - hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName)); + hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName, false)); } protected override bool IsPropertyDeferred(int propertyIndex) { - return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex)]; + return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex, false)]; } //Since v5.3 @@ -713,9 +713,9 @@ public override bool HasSequentialSelect //Since v5.3 [Obsolete("This method has no more usage in NHibernate and will be removed in a future version.")] - public new int GetSubclassPropertyTableNumber(string propertyName, string entityName) + public new int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) { - return base.GetSubclassPropertyTableNumber(propertyName, entityName); + return base.GetSubclassPropertyTableNumber(propertyName, entityName, useLastIndex); } //Since v5.3 diff --git a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs index c8d2c834cac..0b1dd021b61 100644 --- a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs @@ -289,12 +289,12 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i) + protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) { return 0; } - public override int GetSubclassPropertyTableNumber(string propertyName) + public override int GetSubclassPropertyTableNumber(string propertyName, bool useLastIndex) { return 0; } From 45a15b060700c7cfed0d03b00edaa8d101f63b38 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Fri, 10 Apr 2020 14:02:44 +0300 Subject: [PATCH 05/43] revert config changes --- src/NHibernate.Test/App.config | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/App.config b/src/NHibernate.Test/App.config index 8d0aa714996..d3965012af5 100644 --- a/src/NHibernate.Test/App.config +++ b/src/NHibernate.Test/App.config @@ -7,7 +7,7 @@ - + @@ -31,7 +31,7 @@ NHibernate.Dialect.MsSql2008Dialect NHibernate.Driver.Sql2008ClientDriver - Server=localhost;Database=nhibernate;Integrated Security=SSPI + Server=localhost\sqlexpress;Database=nhibernate;Integrated Security=SSPI NHibernate.Test.DebugConnectionProvider, NHibernate.Test ReadCommitted From bc99fedb2e3146a30fa3dae6513b5a3f5a291d0e Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Thu, 16 Apr 2020 00:46:44 +1200 Subject: [PATCH 06/43] Remove unused useLastIndex parameter --- .../Persister/Entity/AbstractEntityPersister.cs | 12 ++++++------ .../Entity/JoinedSubclassEntityPersister.cs | 2 +- .../Persister/Entity/SingleTableEntityPersister.cs | 4 ++-- .../Persister/Entity/UnionSubclassEntityPersister.cs | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 32639db4d0e..6b5431c7239 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -1118,7 +1118,7 @@ protected virtual bool IsIdOfTable(int property, int table) return false; } - protected abstract int GetSubclassPropertyTableNumber(int i, bool useLastIndex); + protected abstract int GetSubclassPropertyTableNumber(int i); internal int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) { @@ -1271,7 +1271,7 @@ protected internal virtual SqlString GenerateLazySelectString() // use the subclass closure int propertyNumber = GetSubclassPropertyIndex(lazyPropertyNames[i]); - int tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); + int tableNumber = GetSubclassPropertyTableNumber(propertyNumber); tableNumbers.Add(tableNumber); int[] colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -1326,7 +1326,7 @@ protected virtual IDictionary GenerateLazySelectStringsByFetc // use the subclass closure var propertyNumber = GetSubclassPropertyIndex(lazyPropertyDescriptor.Name); - var tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); + var tableNumber = GetSubclassPropertyTableNumber(propertyNumber); tableNumbers.Add(tableNumber); var colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -2114,7 +2114,7 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath, bool useL ? Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName) : Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! - return index == -1 ? 0 : GetSubclassPropertyTableNumber(index, false); + return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); } public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath) @@ -2167,7 +2167,7 @@ private string GetSubclassAliasedColumn(string rootAlias, int tableNumber, strin public string[] ToColumns(string name, int i) { - string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i, false)); + string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i)); string[] cols = GetSubclassPropertyColumnNames(i); string[] templates = SubclassPropertyFormulaTemplateClosure[i]; string[] result = new string[cols.Length]; @@ -3685,7 +3685,7 @@ private IDictionary GetColumnsToTableAliasMap(string rootAlias) if (cols != null && cols.Length > 0) { - PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i, false)); + PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i)); propDictionary[key] = property; } } diff --git a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs index d11fb3d37c9..f70643df99e 100644 --- a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs @@ -546,7 +546,7 @@ public override string[] ToColumns(string alias, string propertyName, bool useLa } } - protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) + protected override int GetSubclassPropertyTableNumber(int i) { return subclassPropertyTableNumberClosure[i]; } diff --git a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs index 4dbb004fc8e..785ac6a69c4 100644 --- a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs @@ -675,7 +675,7 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) + protected override int GetSubclassPropertyTableNumber(int i) { return subclassPropertyTableNumberClosure[i]; } @@ -701,7 +701,7 @@ protected override bool IsSubclassPropertyDeferred(string propertyName, string e protected override bool IsPropertyDeferred(int propertyIndex) { - return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex, false)]; + return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex)]; } //Since v5.3 diff --git a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs index 0b1dd021b61..7c4b5f7c96e 100644 --- a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs @@ -289,7 +289,7 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) + protected override int GetSubclassPropertyTableNumber(int i) { return 0; } From 6abb7448f3b951fa0bb1d4626a2bc1b43737cb4c Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Thu, 16 Apr 2020 00:50:07 +1200 Subject: [PATCH 07/43] Remove unused useLastIndex parameter --- src/NHibernate/Persister/Entity/AbstractEntityPersister.cs | 2 +- .../Persister/Entity/SingleTableEntityPersister.cs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 6b5431c7239..3213b39a4c3 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -1120,7 +1120,7 @@ protected virtual bool IsIdOfTable(int property, int table) protected abstract int GetSubclassPropertyTableNumber(int i); - internal int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) + internal int GetSubclassPropertyTableNumber(string propertyName, string entityName) { var type = propertyMapping.ToType(propertyName); if (type.IsAssociationType && ((IAssociationType) type).UseLHSPrimaryKey) diff --git a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs index 785ac6a69c4..9aa8a71a3e4 100644 --- a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs @@ -696,7 +696,7 @@ protected override void AddDiscriminatorToInsert(SqlInsertBuilder insert) protected override bool IsSubclassPropertyDeferred(string propertyName, string entityName) { return - hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName, false)); + hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName)); } protected override bool IsPropertyDeferred(int propertyIndex) @@ -713,9 +713,9 @@ public override bool HasSequentialSelect //Since v5.3 [Obsolete("This method has no more usage in NHibernate and will be removed in a future version.")] - public new int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) + public new int GetSubclassPropertyTableNumber(string propertyName, string entityName) { - return base.GetSubclassPropertyTableNumber(propertyName, entityName, useLastIndex); + return base.GetSubclassPropertyTableNumber(propertyName, entityName); } //Since v5.3 From 26f4b63bcef935ed120ba6a8db74f69347d64857 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sat, 21 Mar 2020 08:36:57 +0100 Subject: [PATCH 08/43] Add cross join support for Hql and Linq query provider (#2327) Fixes: #1128, closes #1060 --- .../FetchLazyPropertiesFixture.cs | 38 +++ .../Async/Hql/EntityJoinHqlTest.cs | 21 ++ .../Async/Linq/ByMethod/JoinTests.cs | 33 +++ .../Async/Linq/LinqQuerySamples.cs | 218 ++++++++++++++--- .../FetchLazyPropertiesFixture.cs | 38 +++ src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 23 +- .../Linq/ByMethod/JoinTests.cs | 35 ++- src/NHibernate.Test/Linq/LinqQuerySamples.cs | 224 +++++++++++++++--- src/NHibernate.Test/TestCase.cs | 66 ++++++ src/NHibernate/AdoNet/Util/BasicFormatter.cs | 1 + src/NHibernate/Dialect/DB2Dialect.cs | 3 + src/NHibernate/Dialect/Dialect.cs | 5 + src/NHibernate/Dialect/InformixDialect0940.cs | 5 +- src/NHibernate/Dialect/Oracle10gDialect.cs | 3 + src/NHibernate/Dialect/Oracle8iDialect.cs | 3 + src/NHibernate/Dialect/SybaseASA9Dialect.cs | 3 + src/NHibernate/Dialect/SybaseASE15Dialect.cs | 5 +- src/NHibernate/Hql/Ast/ANTLR/Hql.g | 2 + src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g | 3 + .../Hql/Ast/ANTLR/Util/JoinProcessor.cs | 2 + src/NHibernate/Hql/Ast/HqlTreeBuilder.cs | 10 + src/NHibernate/Hql/Ast/HqlTreeNode.cs | 31 +++ src/NHibernate/Linq/Clauses/NhJoinClause.cs | 2 + .../Linq/ReWriters/AddJoinsReWriter.cs | 32 ++- src/NHibernate/Linq/Visitors/JoinBuilder.cs | 18 +- .../Linq/Visitors/QueryModelVisitor.cs | 59 +++-- .../Linq/Visitors/WhereJoinDetector.cs | 11 + src/NHibernate/SqlCommand/ANSIJoinFragment.cs | 11 +- src/NHibernate/SqlCommand/JoinFragment.cs | 3 +- 29 files changed, 813 insertions(+), 95 deletions(-) diff --git a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs index a971ec6d132..8fe1e68ba81 100644 --- a/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/Async/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -9,12 +9,14 @@ using System; +using System.Collections.Generic; using System.Linq; using NHibernate.Cache; using NHibernate.Cfg; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq; using NUnit.Framework; +using NUnit.Framework.Constraints; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Test.FetchLazyProperties @@ -943,6 +945,42 @@ public async Task TestFetchAfterEntityIsInitializedAsync(bool readOnly) Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), Is.True); } + [Test] + public async Task TestHqlCrossJoinFetchFormulaAsync() + { + if (!Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + var persons = new List(); + var bestFriends = new List(); + using (var sqlSpy = new SqlLogSpy()) + using (var s = OpenSession()) + { + var list = await (s.CreateQuery("select p, bf from Person p cross join Person bf fetch bf.Formula where bf.Id = p.BestFriend.Id").ListAsync()); + foreach (var arr in list) + { + persons.Add((Person) arr[0]); + bestFriends.Add((Person) arr[1]); + } + } + + AssertPersons(persons, false); + AssertPersons(bestFriends, true); + + void AssertPersons(List results, bool fetched) + { + foreach (var person in results) + { + Assert.That(person, Is.Not.Null); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Image"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Address"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), fetched ? Is.True : (IResolveConstraint) Is.False); + } + } + } + private static Person GeneratePerson(int i, Person bestFriend) { return new Person diff --git a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs index c573a7666a8..dcca7fc2924 100644 --- a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs @@ -274,6 +274,27 @@ public async Task EntityJoinWithFetchesAsync() } } + [Test] + public async Task CrossJoinAndWhereClauseAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var result = await (session.CreateQuery( + "SELECT s " + + "FROM EntityComplex s cross join EntityComplex q " + + "where s.SameTypeChild.Id = q.SameTypeChild.Id" + ).ListAsync()); + + Assert.That(result, Has.Count.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + if (Dialect.SupportsCrossJoin) + { + Assert.That(sqlLog.GetWholeLog(), Does.Contain("cross join"), "A cross join is expected in the SQL select"); + } + } + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 6d11edace21..86802cdb345 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -8,7 +8,13 @@ //------------------------------------------------------------------------------ +using System; using System.Linq; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; using NUnit.Framework; using NHibernate.Linq; @@ -31,5 +37,32 @@ public async Task MultipleLinqJoinsWithSameProjectionNamesAsync() Assert.That(orders.Count, Is.EqualTo(828)); Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } + + [TestCase(false)] + [TestCase(true)] + public async Task CrossJoinWithPredicateInWhereStatementAsync(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var result = await ((from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(result.Count, Is.EqualTo(720)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 0 : 1)); + } + } } } diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index 44fd7ef3cd8..a332a24f6ce 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -9,9 +9,11 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NSubstitute; using NUnit.Framework; using NHibernate.Linq; @@ -757,7 +759,13 @@ from o in c.Orders where c.Address.City == "London" select o; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -863,7 +871,13 @@ from p in db.Products where p.Supplier.Address.Country == "USA" && p.UnitsInStock == 0 select p; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -879,7 +893,16 @@ from et in e.Territories where e.Address.City == "Seattle" select new {e.FirstName, e.LastName, et.Region.Description}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + // EmployeeTerritories and Territories + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + // Region + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -903,7 +926,13 @@ from e2 in e1.Subordinates e1.Address.City }; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -918,7 +947,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders select new {c.ContactName, OrderCount = orders.Average(x => x.Freight)}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -930,7 +965,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId select new { c.ContactName, o.OrderId }; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -959,15 +1000,76 @@ from c in db.Customers } [Category("JOIN")] - [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] - public async Task DLinqJoin5dAsync() + [TestCase(true, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + [TestCase(false, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5dAsync(bool useCrossJoin) { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var q = from c in db.Customers - join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } + join o in db.Orders on + new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals + new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; - await (ObjectDumper.WriteAsync(q)); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + Assert.That(GetTotalOccurrences(sql, "cross join"), Is.EqualTo(useCrossJoin ? 1 : 0)); + } + } + + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public async Task DLinqJoin5eAsync() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5fAsync() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } } [Category("JOIN")] @@ -983,7 +1085,13 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords join e in db.Employees on c.Address.City equals e.Address.City into emps select new {c.ContactName, ords = ords.Count(), emps = emps.Count()}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -997,32 +1105,55 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords from o in ords select new {c.ContactName, o.OrderId, z}; - await (ObjectDumper.WriteAsync(q)); + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] - [Test(Description = "This sample shows a group join with a composite key.")] - public async Task DLinqJoin9Async() - { - var expected = - (from o in db.Orders.ToList() - from p in db.Products.ToList() - join d in db.OrderLines.ToList() - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); - - var actual = - await ((from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToListAsync()); - - Assert.AreEqual(expected.Count, actual.Count); + [TestCase(true, Description = "This sample shows a group join with a composite key.")] + [TestCase(false, Description = "This sample shows a group join with a composite key.")] + public async Task DLinqJoin9Async(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var actual = + await ((from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(2155)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + } } [Category("JOIN")] @@ -1036,5 +1167,24 @@ group o by c into x await (ObjectDumper.WriteAsync(q)); } + + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables.")] + public async Task DLinqJoin10aAsync() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } } } diff --git a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs index 04cb354f3f4..49a75c4b107 100644 --- a/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs +++ b/src/NHibernate.Test/FetchLazyProperties/FetchLazyPropertiesFixture.cs @@ -1,10 +1,12 @@ using System; +using System.Collections.Generic; using System.Linq; using NHibernate.Cache; using NHibernate.Cfg; using NHibernate.Hql.Ast.ANTLR; using NHibernate.Linq; using NUnit.Framework; +using NUnit.Framework.Constraints; using Environment = NHibernate.Cfg.Environment; namespace NHibernate.Test.FetchLazyProperties @@ -932,6 +934,42 @@ public void TestFetchAfterEntityIsInitialized(bool readOnly) Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), Is.True); } + [Test] + public void TestHqlCrossJoinFetchFormula() + { + if (!Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + var persons = new List(); + var bestFriends = new List(); + using (var sqlSpy = new SqlLogSpy()) + using (var s = OpenSession()) + { + var list = s.CreateQuery("select p, bf from Person p cross join Person bf fetch bf.Formula where bf.Id = p.BestFriend.Id").List(); + foreach (var arr in list) + { + persons.Add((Person) arr[0]); + bestFriends.Add((Person) arr[1]); + } + } + + AssertPersons(persons, false); + AssertPersons(bestFriends, true); + + void AssertPersons(List results, bool fetched) + { + foreach (var person in results) + { + Assert.That(person, Is.Not.Null); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Image"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Address"), Is.False); + Assert.That(NHibernateUtil.IsPropertyInitialized(person, "Formula"), fetched ? Is.True : (IResolveConstraint) Is.False); + } + } + } + private static Person GeneratePerson(int i, Person bestFriend) { return new Person diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index 0b8d0c11a6b..0a578bd982a 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -264,7 +264,7 @@ public void EntityJoinWithFetches() } [Test, Ignore("Failing for unrelated reasons")] - public void CrossJoinAndWithClause() + public void ImplicitJoinAndWithClause() { //This is about complex theta style join fix that was implemented in hibernate along with entity join functionality //https://hibernate.atlassian.net/browse/HHH-7321 @@ -279,6 +279,27 @@ public void CrossJoinAndWithClause() } } + [Test] + public void CrossJoinAndWhereClause() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var result = session.CreateQuery( + "SELECT s " + + "FROM EntityComplex s cross join EntityComplex q " + + "where s.SameTypeChild.Id = q.SameTypeChild.Id" + ).List(); + + Assert.That(result, Has.Count.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + if (Dialect.SupportsCrossJoin) + { + Assert.That(sqlLog.GetWholeLog(), Does.Contain("cross join"), "A cross join is expected in the SQL select"); + } + } + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index 1d78b46b181..a013014cae8 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -1,4 +1,10 @@ -using System.Linq; +using System; +using System.Linq; +using System.Reflection; +using NHibernate.Cfg; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; using NUnit.Framework; namespace NHibernate.Test.Linq.ByMethod @@ -19,5 +25,32 @@ public void MultipleLinqJoinsWithSameProjectionNames() Assert.That(orders.Count, Is.EqualTo(828)); Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); } + + [TestCase(false)] + [TestCase(true)] + public void CrossJoinWithPredicateInWhereStatement(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var result = (from o in db.Orders + from o2 in db.Orders.Where(x => x.Freight > 50) + where (o.OrderId == o2.OrderId + 1) || (o.OrderId == o2.OrderId - 1) + select new { o.OrderId, OrderId2 = o2.OrderId }).ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(result.Count, Is.EqualTo(720)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 0 : 1)); + } + } } } diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index efe18a3caad..193c313130c 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1,7 +1,9 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using NHibernate.DomainModel.Northwind.Entities; +using NSubstitute; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -1301,7 +1303,13 @@ from o in c.Orders where c.Address.City == "London" select o; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1407,7 +1415,13 @@ from p in db.Products where p.Supplier.Address.Country == "USA" && p.UnitsInStock == 0 select p; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1423,7 +1437,16 @@ from et in e.Territories where e.Address.City == "Seattle" select new {e.FirstName, e.LastName, et.Region.Description}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + // EmployeeTerritories and Territories + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + // Region + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1447,7 +1470,13 @@ from e2 in e1.Subordinates e1.Address.City }; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1462,7 +1491,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders select new {c.ContactName, OrderCount = orders.Average(x => x.Freight)}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -1474,7 +1509,13 @@ from c in db.Customers join o in db.Orders on c.CustomerId equals o.Customer.CustomerId select new { c.ContactName, o.OrderId }; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1503,15 +1544,76 @@ from c in db.Customers } [Category("JOIN")] - [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] - public void DLinqJoin5d() + [TestCase(true, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + [TestCase(false, Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5d(bool useCrossJoin) { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + var q = from c in db.Customers - join o in db.Orders on new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } + join o in db.Orders on + new {c.CustomerId, HasContractTitle = c.ContactTitle != null} equals + new {o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } select new { c.ContactName, o.OrderId }; - ObjectDumper.Write(q); + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + Assert.That(GetTotalOccurrences(sql, "cross join"), Is.EqualTo(useCrossJoin ? 1 : 0)); + } + } + + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public void DLinqJoin5e() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } + } + + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5f() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(0)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } } [Category("JOIN")] @@ -1527,7 +1629,13 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords join e in db.Employees on c.Address.City equals e.Address.City into emps select new {c.ContactName, ords = ords.Count(), emps = emps.Count()}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "join"), Is.EqualTo(0)); + } } [Category("JOIN")] @@ -1544,7 +1652,13 @@ join o in db.Orders on e equals o.Employee into ords from o in ords.DefaultIfEmpty() select new {e.FirstName, e.LastName, Order = o}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [Category("JOIN")] @@ -1558,32 +1672,55 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into ords from o in ords select new {c.ContactName, o.OrderId, z}; - ObjectDumper.Write(q); + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(1)); + } } [Category("JOIN")] - [Test(Description = "This sample shows a group join with a composite key.")] - public void DLinqJoin9() - { - var expected = - (from o in db.Orders.ToList() - from p in db.Products.ToList() - join d in db.OrderLines.ToList() - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); - - var actual = - (from o in db.Orders - from p in db.Products - join d in db.OrderLines - on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} - into details - from d in details - select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + [TestCase(true, Description = "This sample shows a group join with a composite key.")] + [TestCase(false, Description = "This sample shows a group join with a composite key.")] + public void DLinqJoin9(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } - Assert.AreEqual(expected.Count, actual.Count); + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + // into details + // from d in details + // select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var actual = + (from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + into details + from d in details + select new { o.OrderId, p.ProductId, d.UnitPrice }).ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(2155)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(useCrossJoin ? 1 : 2)); + } } [Category("JOIN")] @@ -1598,6 +1735,25 @@ group o by c into x ObjectDumper.Write(q); } + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables.")] + public void DLinqJoin10a() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } + [Category("WHERE")] [Test(Description = "This sample uses WHERE to filter for orders with shipping date equals to null.")] public void DLinq2B() diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 4c9ea610c6d..0d59b989cf4 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -16,6 +16,9 @@ using System.Text; using NHibernate.Dialect; using NHibernate.Driver; +using NHibernate.Engine.Query; +using NHibernate.Util; +using NSubstitute; namespace NHibernate.Test { @@ -477,6 +480,69 @@ protected void AssumeFunctionSupported(string functionName) $"{Dialect} doesn't support {functionName} standard function."); } + protected void ClearQueryPlanCache() + { + var planCacheField = typeof(QueryPlanCache) + .GetField("planCache", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new InvalidOperationException("planCache field does not exist in QueryPlanCache."); + + var planCache = (SoftLimitMRUCache) planCacheField.GetValue(Sfi.QueryPlanCache); + planCache.Clear(); + } + + protected Substitute SubstituteDialect() + { + var origDialect = Sfi.Settings.Dialect; + var dialectProperty = (PropertyInfo) ReflectHelper.GetProperty(o => o.Dialect); + var forPartsOfMethod = ReflectHelper.GetMethodDefinition(() => Substitute.ForPartsOf()); + var substitute = (Dialect.Dialect) forPartsOfMethod.MakeGenericMethod(origDialect.GetType()) + .Invoke(null, new object[] { new object[0] }); + + dialectProperty.SetValue(Sfi.Settings, substitute); + + return new Substitute(substitute, Dispose); + + void Dispose() + { + dialectProperty.SetValue(Sfi.Settings, origDialect); + } + } + + protected static int GetTotalOccurrences(string content, string substring) + { + if (string.IsNullOrEmpty(substring)) + { + throw new ArgumentNullException(nameof(substring)); + } + + int occurrences = 0, index = 0; + while ((index = content.IndexOf(substring, index)) >= 0) + { + occurrences++; + index += substring.Length; + } + + return occurrences; + } + + protected struct Substitute : IDisposable + { + private readonly System.Action _disposeAction; + + public Substitute(TType value, System.Action disposeAction) + { + Value = value; + _disposeAction = disposeAction; + } + + public TType Value { get; } + + public void Dispose() + { + _disposeAction(); + } + } + #endregion } } diff --git a/src/NHibernate/AdoNet/Util/BasicFormatter.cs b/src/NHibernate/AdoNet/Util/BasicFormatter.cs index 2e601990e06..c3f35fbddf1 100644 --- a/src/NHibernate/AdoNet/Util/BasicFormatter.cs +++ b/src/NHibernate/AdoNet/Util/BasicFormatter.cs @@ -20,6 +20,7 @@ static BasicFormatter() { beginClauses.Add("left"); beginClauses.Add("right"); + beginClauses.Add("cross"); beginClauses.Add("inner"); beginClauses.Add("outer"); beginClauses.Add("group"); diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs index d56ab6dc06f..3eef1635595 100644 --- a/src/NHibernate/Dialect/DB2Dialect.cs +++ b/src/NHibernate/Dialect/DB2Dialect.cs @@ -300,6 +300,9 @@ public override string ForUpdateString public override bool SupportsResultSetPositionQueryMethodsOnForwardOnlyCursor => false; + /// + public override bool SupportsCrossJoin => false; // DB2 v9.1 doesn't support 'cross join' syntax + public override bool SupportsLobValueChangePropogation => false; public override bool SupportsExistsInSelect => false; diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index e3d0b6cacd9..423c5082361 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -1337,6 +1337,11 @@ public virtual JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } + /// + /// Does this dialect support CROSS JOIN? + /// + public virtual bool SupportsCrossJoin => true; + /// /// Create a strategy responsible /// for handling this dialect's variations in how CASE statements are diff --git a/src/NHibernate/Dialect/InformixDialect0940.cs b/src/NHibernate/Dialect/InformixDialect0940.cs index 13563df43e8..bdfcae4f27f 100644 --- a/src/NHibernate/Dialect/InformixDialect0940.cs +++ b/src/NHibernate/Dialect/InformixDialect0940.cs @@ -126,7 +126,10 @@ public override JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } - /// + /// + public override bool SupportsCrossJoin => false; + + /// /// Does this Dialect have some kind of LIMIT syntax? /// /// False, unless overridden. diff --git a/src/NHibernate/Dialect/Oracle10gDialect.cs b/src/NHibernate/Dialect/Oracle10gDialect.cs index 7219390c0e9..caab3e1f492 100644 --- a/src/NHibernate/Dialect/Oracle10gDialect.cs +++ b/src/NHibernate/Dialect/Oracle10gDialect.cs @@ -15,5 +15,8 @@ public override JoinFragment CreateOuterJoinFragment() { return new ANSIJoinFragment(); } + + /// + public override bool SupportsCrossJoin => true; } } \ No newline at end of file diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index f3a943a3b8e..b2103b87965 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -328,6 +328,9 @@ public override JoinFragment CreateOuterJoinFragment() return new OracleJoinFragment(); } + /// + public override bool SupportsCrossJoin => false; + /// /// Map case support to the Oracle DECODE function. Oracle did not /// add support for CASE until 9i. diff --git a/src/NHibernate/Dialect/SybaseASA9Dialect.cs b/src/NHibernate/Dialect/SybaseASA9Dialect.cs index 12ad1e3c088..f839871b6a1 100644 --- a/src/NHibernate/Dialect/SybaseASA9Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASA9Dialect.cs @@ -97,6 +97,9 @@ public override bool OffsetStartsAtOne get { return true; } } + /// + public override bool SupportsCrossJoin => false; + public override SqlString GetLimitString(SqlString queryString, SqlString offset, SqlString limit) { int intSelectInsertPoint = GetAfterSelectInsertPoint(queryString); diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index 1b33b672cdf..0a84a822424 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -247,7 +247,10 @@ public override bool SupportsExpectedLobUsagePattern { get { return false; } } - + + /// + public override bool SupportsCrossJoin => false; + public override char OpenQuote { get { return '['; } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Hql.g b/src/NHibernate/Hql/Ast/ANTLR/Hql.g index 9f67aaa2162..efe660d8918 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Hql.g +++ b/src/NHibernate/Hql/Ast/ANTLR/Hql.g @@ -19,6 +19,7 @@ tokens BETWEEN='between'; CLASS='class'; COUNT='count'; + CROSS='cross'; DELETE='delete'; DESCENDING='desc'; DOT; @@ -255,6 +256,7 @@ fromClause fromJoin : ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? path (asAlias)? (propertyFetch)? (withClause)? | ( ( ( LEFT | RIGHT ) (OUTER)? ) | FULL | INNER )? JOIN^ (FETCH)? ELEMENTS! OPEN! path CLOSE! (asAlias)? (propertyFetch)? (withClause)? + | CROSS JOIN^ { WeakKeywords(); } path (asAlias)? (propertyFetch)? ; withClause diff --git a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g index 22e5450dce1..fba1010337c 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g +++ b/src/NHibernate/Hql/Ast/ANTLR/HqlSqlWalker.g @@ -306,6 +306,9 @@ joinType returns [int j] | INNER { $j = INNER; } + | CROSS { + $j = CROSS; + } ; // Matches a path and returns the normalized string for the path (usually diff --git a/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs b/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs index c4f850f7aa3..8fbcfba7fd7 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Util/JoinProcessor.cs @@ -48,6 +48,8 @@ public static JoinType ToHibernateJoinType(int astJoinType) return JoinType.RightOuterJoin; case HqlSqlWalker.FULL: return JoinType.FullJoin; + case HqlSqlWalker.CROSS: + return JoinType.CrossJoin; default: throw new AssertionFailure("undefined join type " + astJoinType); } diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index 49cad7607e1..bb208295afc 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -478,11 +478,21 @@ public HqlIn In(HqlExpression itemExpression, HqlTreeNode source) return new HqlIn(_factory, itemExpression, source); } + public HqlInnerJoin InnerJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlInnerJoin(_factory, expression, @alias); + } + public HqlLeftJoin LeftJoin(HqlExpression expression, HqlAlias @alias) { return new HqlLeftJoin(_factory, expression, @alias); } + public HqlCrossJoin CrossJoin(HqlExpression expression, HqlAlias @alias) + { + return new HqlCrossJoin(_factory, expression, @alias); + } + public HqlFetchJoin FetchJoin(HqlExpression expression, HqlAlias @alias) { return new HqlFetchJoin(_factory, expression, @alias); diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 5964b99db90..160739d7f16 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -844,6 +844,14 @@ public HqlJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : } } + public class HqlInnerJoin : HqlTreeNode + { + public HqlInnerJoin(IASTFactory factory, HqlExpression expression, HqlAlias alias) + : base(HqlSqlWalker.JOIN, "join", factory, new HqlInner(factory), expression, alias) + { + } + } + public class HqlLeftJoin : HqlTreeNode { public HqlLeftJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlLeft(factory), expression, @alias) @@ -851,6 +859,13 @@ public class HqlLeftJoin : HqlTreeNode } } + public class HqlCrossJoin : HqlTreeNode + { + public HqlCrossJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) : base(HqlSqlWalker.JOIN, "join", factory, new HqlCross(factory), expression, @alias) + { + } + } + public class HqlFetchJoin : HqlTreeNode { public HqlFetchJoin(IASTFactory factory, HqlExpression expression, HqlAlias @alias) @@ -898,6 +913,14 @@ public HqlBitwiseAnd(IASTFactory factory, HqlExpression lhs, HqlExpression rhs) } } + public class HqlInner : HqlTreeNode + { + public HqlInner(IASTFactory factory) + : base(HqlSqlWalker.INNER, "inner", factory) + { + } + } + public class HqlLeft : HqlTreeNode { public HqlLeft(IASTFactory factory) @@ -906,6 +929,14 @@ public HqlLeft(IASTFactory factory) } } + public class HqlCross : HqlTreeNode + { + public HqlCross(IASTFactory factory) + : base(HqlSqlWalker.CROSS, "cross", factory) + { + } + } + public class HqlAny : HqlBooleanExpression { public HqlAny(IASTFactory factory) : base(HqlSqlWalker.ANY, "any", factory) diff --git a/src/NHibernate/Linq/Clauses/NhJoinClause.cs b/src/NHibernate/Linq/Clauses/NhJoinClause.cs index d68926f88b3..3a1c17147aa 100644 --- a/src/NHibernate/Linq/Clauses/NhJoinClause.cs +++ b/src/NHibernate/Linq/Clauses/NhJoinClause.cs @@ -54,6 +54,8 @@ public NhJoinClause(string itemName, System.Type itemType, Expression fromExpres public bool IsInner { get; private set; } + internal JoinClause ParentJoinClause { get; set; } + public void TransformExpressions(Func transformation) { if (transformation == null) throw new ArgumentNullException(nameof(transformation)); diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index 576bb65a9ea..d022e1ffc88 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Specialized; using System.Linq; using NHibernate.Engine; using NHibernate.Linq.Clauses; @@ -19,11 +20,12 @@ public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider private readonly ISessionFactoryImplementor _sessionFactory; private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; + private JoinClause _currentJoin; private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) { _sessionFactory = sessionFactory; - var joiner = new Joiner(queryModel); + var joiner = new Joiner(queryModel, AddJoin); _memberExpressionJoinDetector = new MemberExpressionJoinDetector(this, joiner); _whereJoinDetector = new WhereJoinDetector(this, joiner); } @@ -59,6 +61,22 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel _whereJoinDetector.Transform(havingClause); } + public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) + { + joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence); + + // When associations are located in the outer key (e.g. from a in A join b in B b on a.C.D.Id equals b.Id), + // do nothing and leave them to HQL for adding the missing joins. + + // When associations are located in the inner key (e.g. from a in A join b in B b on a.Id equals b.C.D.Id), + // we have to move the condition to the where statement, otherwise the query will be invalid (HQL does not + // support them). + // Link newly created joins with the current join clause in order to later detect which join type to use. + _currentJoin = joinClause; + joinClause.InnerKeySelector = _whereJoinDetector.Transform(joinClause.InnerKeySelector); + _currentJoin = null; + } + public bool IsEntity(System.Type type) { return _sessionFactory.GetImplementors(type.FullName).Any(); @@ -69,5 +87,17 @@ public bool IsIdentifier(System.Type type, string propertyName) var metadata = _sessionFactory.GetClassMetadata(type); return metadata != null && propertyName.Equals(metadata.IdentifierPropertyName); } + + private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) + { + joinClause.ParentJoinClause = _currentJoin; + if (_currentJoin != null) + { + // Match the parent join type + joinClause.MakeInner(); + } + + queryModel.BodyClauses.Add(joinClause); + } } } diff --git a/src/NHibernate/Linq/Visitors/JoinBuilder.cs b/src/NHibernate/Linq/Visitors/JoinBuilder.cs index b84a72c7ac3..dc7bc395c13 100644 --- a/src/NHibernate/Linq/Visitors/JoinBuilder.cs +++ b/src/NHibernate/Linq/Visitors/JoinBuilder.cs @@ -21,11 +21,20 @@ public class Joiner : IJoiner private readonly NameGenerator _nameGenerator; private readonly QueryModel _queryModel; + internal Joiner(QueryModel queryModel, System.Action addJoinMethod) + : this(queryModel) + { + AddJoinMethod = addJoinMethod; + } + internal Joiner(QueryModel queryModel) { _nameGenerator = new NameGenerator(queryModel); _queryModel = queryModel; + AddJoinMethod = AddJoin; } + + internal System.Action AddJoinMethod { get; } public IEnumerable Joins { @@ -39,7 +48,7 @@ public Expression AddJoin(Expression expression, string key) if (!_joins.TryGetValue(key, out join)) { join = new NhJoinClause(_nameGenerator.GetNewName(), expression.Type, expression); - _queryModel.BodyClauses.Add(join); + AddJoinMethod(_queryModel, join); _joins.Add(key, join); } @@ -72,6 +81,11 @@ public bool CanAddJoin(Expression expression) return resultOperatorBase != null && _queryModel.ResultOperators.Contains(resultOperatorBase); } + private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) + { + queryModel.BodyClauses.Add(joinClause); + } + private class QuerySourceExtractor : RelinqExpressionVisitor { private IQuerySource _querySource; @@ -90,4 +104,4 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index f043d30c51f..a487a1281ce 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using System.Reflection; using NHibernate.Hql.Ast; @@ -17,6 +18,8 @@ using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Clauses.StreamedData; using Remotion.Linq.EagerFetching; +using OrderByClause = Remotion.Linq.Clauses.OrderByClause; +using SelectClause = Remotion.Linq.Clauses.SelectClause; namespace NHibernate.Linq.Visitors { @@ -315,22 +318,16 @@ public override void VisitMainFromClause(MainFromClause fromClause, QueryModel q public override void VisitAdditionalFromClause(AdditionalFromClause fromClause, QueryModel queryModel, int index) { var querySourceName = VisitorParameters.QuerySourceNamer.GetName(fromClause); - + var fromExpressionTree = HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters); + var alias = _hqlTree.TreeBuilder.Alias(querySourceName); if (fromClause.FromExpression is MemberExpression) { // It's a join - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Join( - HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters).AsExpression(), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(_hqlTree.TreeBuilder.Join(fromExpressionTree.AsExpression(), alias)); } else { - // TODO - exact same code as in MainFromClause; refactor this out - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionVisitor.Visit(fromClause.FromExpression, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(CreateCrossJoin(fromExpressionTree, alias)); } base.VisitAdditionalFromClause(fromClause, queryModel, index); @@ -517,15 +514,24 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) { var equalityVisitor = new EqualityHqlGenerator(VisitorParameters); - var whereClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); - var querySourceName = VisitorParameters.QuerySourceNamer.GetName(joinClause); - - _hqlTree.AddWhereClause(whereClause); + var withClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); + var alias = _hqlTree.TreeBuilder.Alias(VisitorParameters.QuerySourceNamer.GetName(joinClause)); + var joinExpression = HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters); + HqlTreeNode join; + // When associations are located inside the inner key selector we have to use a cross join instead of an inner + // join and add the condition in the where statement. + if (queryModel.BodyClauses.OfType().Any(o => o.ParentJoinClause == joinClause)) + { + _hqlTree.AddWhereClause(withClause); + join = CreateCrossJoin(joinExpression, alias); + } + else + { + join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + join.AddChild(_hqlTree.TreeBuilder.With(withClause)); + } - _hqlTree.AddFromClause( - _hqlTree.TreeBuilder.Range( - HqlGeneratorExpressionVisitor.Visit(joinClause.InnerSequence, VisitorParameters), - _hqlTree.TreeBuilder.Alias(querySourceName))); + _hqlTree.AddFromClause(join); } public override void VisitGroupJoinClause(GroupJoinClause groupJoinClause, QueryModel queryModel, int index) @@ -552,5 +558,22 @@ public override void VisitNhWithClause(NhWithClause withClause, QueryModel query var expression = HqlGeneratorExpressionVisitor.Visit(withClause.Predicate, VisitorParameters).ToBooleanExpression(); _hqlTree.AddWhereClause(expression); } + + private HqlTreeNode CreateCrossJoin(HqlTreeNode joinExpression, HqlAlias alias) + { + if (VisitorParameters.SessionFactory.Dialect.SupportsCrossJoin) + { + return _hqlTree.TreeBuilder.CrossJoin(joinExpression.AsExpression(), alias); + } + + // Simulate cross join as a inner join on 1=1 + var join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + var onExpression = _hqlTree.TreeBuilder.Equality( + _hqlTree.TreeBuilder.True(), + _hqlTree.TreeBuilder.True()); + join.AddChild(_hqlTree.TreeBuilder.With(onExpression)); + + return join; + } } } diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 66508f01eef..68f88c6cc55 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -77,10 +77,21 @@ internal WhereJoinDetector(IIsEntityDecider isEntityDecider, IJoiner joiner) _joiner = joiner; } + public Expression Transform(Expression expression) + { + var result = Visit(expression); + PostTransform(); + return result; + } + public void Transform(IClause whereClause) { whereClause.TransformExpressions(Visit); + PostTransform(); + } + private void PostTransform() + { var values = _values.Pop(); foreach (var memberExpression in values.MemberExpressions) diff --git a/src/NHibernate/SqlCommand/ANSIJoinFragment.cs b/src/NHibernate/SqlCommand/ANSIJoinFragment.cs index aa9421fa08e..063c361eb56 100644 --- a/src/NHibernate/SqlCommand/ANSIJoinFragment.cs +++ b/src/NHibernate/SqlCommand/ANSIJoinFragment.cs @@ -33,12 +33,21 @@ public override void AddJoin(string tableName, string alias, string[] fkColumns, case JoinType.FullJoin: joinString = " full outer join "; break; + case JoinType.CrossJoin: + joinString = " cross join "; + break; default: throw new AssertionFailure("undefined join type"); } - _fromFragment.Add(joinString + tableName + ' ' + alias + " on "); + _fromFragment.Add(joinString).Add(tableName).Add(" ").Add(alias).Add(" "); + if (joinType == JoinType.CrossJoin) + { + // Cross join does not have an 'on' statement + return; + } + _fromFragment.Add("on "); if (fkColumns.Length == 0) { AddBareCondition(_fromFragment, on); diff --git a/src/NHibernate/SqlCommand/JoinFragment.cs b/src/NHibernate/SqlCommand/JoinFragment.cs index 18b249bd57f..38c5d8ce280 100644 --- a/src/NHibernate/SqlCommand/JoinFragment.cs +++ b/src/NHibernate/SqlCommand/JoinFragment.cs @@ -10,7 +10,8 @@ public enum JoinType InnerJoin = 0, FullJoin = 4, LeftOuterJoin = 1, - RightOuterJoin = 2 + RightOuterJoin = 2, + CrossJoin = 8 } /// From 7c9f229f1fa5993976c34b5196a54ed3e4d843a3 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Thu, 19 Mar 2020 10:05:54 +0200 Subject: [PATCH 09/43] improve adding of with clauses when entity overrides property from base class --- src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 46 +++++++++++- src/NHibernate.Test/Hql/Node.cs | 75 +++++++++++++++++++ .../Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 5 ++ .../Criteria/CriteriaQueryTranslator.cs | 2 +- .../Collection/AbstractCollectionPersister.cs | 6 ++ .../Collection/CollectionPropertyMapping.cs | 8 +- .../Collection/ElementPropertyMapping.cs | 8 +- .../Entity/AbstractEntityPersister.cs | 5 ++ .../Entity/AbstractPropertyMapping.cs | 6 ++ .../Entity/BasicEntityPropertyMapping.cs | 14 ++++ .../Persister/Entity/IPropertyMapping.cs | 4 +- 11 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 src/NHibernate.Test/Hql/Node.cs diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index 0a578bd982a..fcd64564947 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -1,5 +1,8 @@ -using System.Text.RegularExpressions; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; using NHibernate.Cfg.MappingSchema; +using NHibernate.Criterion; using NHibernate.Mapping.ByCode; using NHibernate.Test.Hql.EntityJoinHqlTestEntities; using NUnit.Framework; @@ -300,6 +303,43 @@ public void CrossJoinAndWhereClause() } } + [Test] + public void Join_Inheritance() + { + // arrange + IEnumerable results; + var person = new PersonBase { Login = "dave", FamilyName = "grohl" }; + var visit_1 = new UserEntityVisit { PersonBase = person }; + var visit_2 = new UserEntityVisit { PersonBase = person }; + + using (ISession arrangeSession = OpenSession()) + using (ITransaction tx = arrangeSession.BeginTransaction()) + { + arrangeSession.Save(person); + arrangeSession.Save(visit_1); + arrangeSession.Save(visit_2); + arrangeSession.Flush(); + + tx.Commit(); + } + + // act + using (var session = OpenSession()) + { + results = session.CreateCriteria() + .CreateCriteria( + $"{nameof(UserEntityVisit.PersonBase)}", + "f", + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Eq("Deleted", false)) + .List() + .Select(x => x.Id); + } + + // assert + Assert.That(results, Is.EquivalentTo(new[] { visit_1.Id, visit_2.Id, })); + } + #region Test Setup protected override HbmMapping GetMappings() @@ -372,6 +412,10 @@ protected override HbmMapping GetMappings() rc.Property(e => e.Name); }); + + Node.AddMapping(mapper); + UserEntityVisit.AddMapping(mapper); + return mapper.CompileMappingForAllExplicitlyAddedEntities(); } diff --git a/src/NHibernate.Test/Hql/Node.cs b/src/NHibernate.Test/Hql/Node.cs new file mode 100644 index 00000000000..e3cb2937002 --- /dev/null +++ b/src/NHibernate.Test/Hql/Node.cs @@ -0,0 +1,75 @@ +using System; +using NHibernate.Mapping.ByCode; + +namespace NHibernate.Test.Hql +{ + public abstract class Node + { + private int _id; + public virtual int Id + { + get { return _id; } + set { _id = value; } + } + + public virtual bool Deleted { get; set; } + public virtual string FamilyName { get; set; } + + public static void AddMapping(ModelMapper mapper) + { + mapper.Class(ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.Property(x => x.FamilyName); + ca.Table("Node"); + ca.Abstract(true); + }); + + mapper.JoinedSubclass(ca => + { + ca.Key(x => x.Column("FK_Node_ID")); + ca.Extends(typeof(Node)); + ca.Property(x => x.Deleted); + ca.Property(x => x.Login); + }); + } + } + + [Serializable] + public class PersonBase : Node + { + public virtual string Login { get; set; } + public override bool Deleted { get; set; } + } + + [Serializable] + public class UserEntityVisit + { + private int _id; + public virtual int Id + { + get { return _id; } + set { _id = value; } + } + + public virtual bool Deleted { get; set; } + + private PersonBase _PersonBase; + public virtual PersonBase PersonBase + { + get { return _PersonBase; } + set { _PersonBase = value; } + } + + public static void AddMapping(ModelMapper mapper) + { + mapper.Class(ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.ManyToOne(x => x.PersonBase); + }); + } + } +} diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index bfffb9be928..f434b3b7f51 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,6 +150,11 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + return fromElementType.GetBasePropertyMapping().ToColumns(pathCriteria, GetPropertyPath(propertyName), getSQLAlias); + } + public string[] ToColumns(string alias, string propertyName) { return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 682dc73e084..8ebad5641c4 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -766,7 +766,7 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); + columns = propertyMapping.ToColumns(pathCriteria, propertyName, GetSQLAlias); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 20701df9435..593482633e8 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1386,6 +1386,12 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public string[] ToColumns(string alias, string propertyName) { if ("index".Equals(propertyName)) diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index e9e2f89dc51..de56cb9cd81 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,6 +57,12 @@ public bool TryToType(string propertyName, out IType type) } } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public string[] ToColumns(string alias, string propertyName) { string[] cols; @@ -117,4 +123,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index 20e9899ddb6..ad412a19774 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,6 +47,12 @@ public bool TryToType(string propertyName, out IType outType) } } + public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public string[] ToColumns(string alias, string propertyName) { if (propertyName == null || "id".Equals(propertyName)) @@ -71,4 +77,4 @@ public IType Type #endregion } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 924da726cc1..879a0d670f3 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2050,6 +2050,11 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } + public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + return propertyMapping.ToColumns(pathCriteria, propertyName, getSQLAlias); + } + public virtual string[] ToColumns(string alias, string propertyName) { return propertyMapping.ToColumns(alias, propertyName); diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index c027568bf18..46e8ca70e34 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,6 +44,12 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } + public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + string alias = getSQLAlias(pathCriteria); + return ToColumns(alias, propertyName); + } + public virtual string[] ToColumns(string alias, string propertyName) { //TODO: *two* hashmap lookups here is one too many... diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 02f625bd550..6c7b31a6940 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -1,4 +1,7 @@ +using System; +using NHibernate.Criterion; using NHibernate.Type; +using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Persister.Entity { @@ -26,6 +29,17 @@ public override IType Type get { return persister.Type; } } + public override string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) + { + var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; + if (withClause != null && withClause.PropertyName == propertyName) + { + return base.ToColumns(persister.GenerateTableAlias(getSQLAlias(pathCriteria), 0), propertyName); + } + + return base.ToColumns(pathCriteria, propertyName, getSQLAlias); + } + public override string[] ToColumns(string alias, string propertyName) { return diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index dbe08dd9139..b348d36eae9 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -29,6 +29,8 @@ public interface IPropertyMapping /// true if a type was found, false if not bool TryToType(string propertyName, out IType type); + string[] ToColumns(ICriteria pathCriteria, string propertyName, System.Func getSQLAlias); + /// /// Given a query alias and a property path, return the qualified column name /// @@ -40,4 +42,4 @@ public interface IPropertyMapping /// Given a property path, return the corresponding column name(s). string[] ToColumns(string propertyName); } -} \ No newline at end of file +} From 3e3555473f7e16858e78a1c4acca9bf4cdcaafbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Wed, 3 Oct 2018 20:26:14 +0200 Subject: [PATCH 10/43] Support evaluation of DateTime.Now on db side And of all similar properties: UtcNow, Today, and DateTimeOffset's ones. Part of #959 Co-authored-by: maca88 --- doc/reference/modules/configuration.xml | 51 ++++ .../Async/Linq/MiscellaneousTextFixture.cs | 3 +- .../Linq/MiscellaneousTextFixture.cs | 3 +- .../Linq/PreEvaluationTests.cs | 267 ++++++++++++++++++ src/NHibernate.Test/Linq/TryGetMappedTests.cs | 2 +- src/NHibernate.Test/TestCase.cs | 28 +- src/NHibernate/Cfg/Environment.cs | 42 +++ src/NHibernate/Cfg/Settings.cs | 38 +++ src/NHibernate/Cfg/SettingsFactory.cs | 9 + src/NHibernate/Dialect/DB2Dialect.cs | 4 +- src/NHibernate/Dialect/Dialect.cs | 2 +- src/NHibernate/Dialect/FirebirdDialect.cs | 3 +- src/NHibernate/Dialect/HanaDialectBase.cs | 6 +- src/NHibernate/Dialect/InformixDialect.cs | 3 +- src/NHibernate/Dialect/MsSql2000Dialect.cs | 4 +- src/NHibernate/Dialect/MsSql2008Dialect.cs | 11 +- src/NHibernate/Dialect/MsSqlCeDialect.cs | 3 +- src/NHibernate/Dialect/MySQL55Dialect.cs | 7 + src/NHibernate/Dialect/MySQLDialect.cs | 2 +- src/NHibernate/Dialect/Oracle8iDialect.cs | 4 +- src/NHibernate/Dialect/Oracle9iDialect.cs | 10 + src/NHibernate/Dialect/PostgreSQL81Dialect.cs | 6 + src/NHibernate/Dialect/PostgreSQLDialect.cs | 3 +- src/NHibernate/Dialect/SQLiteDialect.cs | 7 +- src/NHibernate/Dialect/SybaseASA9Dialect.cs | 2 +- src/NHibernate/Dialect/SybaseASE15Dialect.cs | 5 +- .../Dialect/SybaseSQLAnywhere10Dialect.cs | 4 +- .../Dialect/SybaseSQLAnywhere12Dialect.cs | 7 + .../Linq/Functions/DateTimeNowHqlGenerator.cs | 74 +++++ .../DefaultLinqToHqlGeneratorsRegistry.cs | 1 + .../IAllowPreEvaluationHqlGenerator.cs | 24 ++ .../Functions/IHqlGeneratorForProperty.cs | 44 ++- src/NHibernate/Linq/NhLinqExpression.cs | 2 +- src/NHibernate/Linq/NhRelinqQueryParser.cs | 21 +- .../Visitors/MemberExpressionJoinDetector.cs | 7 + .../NhPartialEvaluatingExpressionVisitor.cs | 40 ++- .../Visitors/NullableExpressionDetector.cs | 7 + .../Linq/Visitors/WhereJoinDetector.cs | 17 +- src/NHibernate/NHibernateUtil.cs | 5 + src/NHibernate/Type/DateType.cs | 2 +- src/NHibernate/Type/LocalDateType.cs | 17 ++ src/NHibernate/Util/ReflectHelper.cs | 15 + src/NHibernate/nhibernate-configuration.xsd | 34 +++ 43 files changed, 798 insertions(+), 48 deletions(-) create mode 100644 src/NHibernate.Test/Linq/PreEvaluationTests.cs create mode 100644 src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs create mode 100644 src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs create mode 100644 src/NHibernate/Type/LocalDateType.cs diff --git a/doc/reference/modules/configuration.xml b/doc/reference/modules/configuration.xml index cb67b15519a..ae80b0a1c10 100644 --- a/doc/reference/modules/configuration.xml +++ b/doc/reference/modules/configuration.xml @@ -717,6 +717,57 @@ var session = sessions.OpenSession(conn); + + + linqtohql.legacy_preevaluation + + + Whether to use the legacy pre-evaluation or not in Linq queries. Defaults to true. + + eg. + true | false + + + Legacy pre-evaluation is causing special properties or functions like DateTime.Now + or Guid.NewGuid() to be always evaluated with the .Net runtime and replaced in the + query by parameter values. + + + The new pre-evaluation allows them to be converted to HQL function calls which will be run on the db + side. This allows for example to retrieve the server time instead of the client time, or to generate + UUIDs for each row instead of an unique one for all rows. + + + The new pre-evaluation will likely be enabled by default in the next major version (6.0). + + + + + + linqtohql.fallback_on_preevaluation + + + When the new pre-evaluation is enabled, should methods which translation is not supported by the current + dialect fallback to pre-evaluation? Defaults to false. + + eg. + true | false + + + When this fallback option is enabled while legacy pre-evaluation is disabled, properties or functions + like DateTime.Now or Guid.NewGuid() used in Linq expressions + will not fail when the dialect does not support them, but will instead be pre-evaluated. + + + When this fallback option is disabled while legacy pre-evaluation is disabled, properties or functions + like DateTime.Now or Guid.NewGuid() used in Linq expressions + will fail when the dialect does not support them. + + + This option has no effect if the legacy pre-evaluation is enabled. + + + sql_exception_converter diff --git a/src/NHibernate.Test/Async/Linq/MiscellaneousTextFixture.cs b/src/NHibernate.Test/Async/Linq/MiscellaneousTextFixture.cs index 2d660623a63..588d8e112ee 100644 --- a/src/NHibernate.Test/Async/Linq/MiscellaneousTextFixture.cs +++ b/src/NHibernate.Test/Async/Linq/MiscellaneousTextFixture.cs @@ -27,7 +27,8 @@ public class MiscellaneousTextFixtureAsync : LinqTestCase [Test(Description = "This sample uses Count to find the number of Orders placed before yesterday in the database.")] public async Task CountWithWhereClauseAsync() { - var q = from o in db.Orders where o.OrderDate <= DateTime.Today.AddDays(-1) select o; + var yesterday = DateTime.Today.AddDays(-1); + var q = from o in db.Orders where o.OrderDate <= yesterday select o; var count = await (q.CountAsync()); diff --git a/src/NHibernate.Test/Linq/MiscellaneousTextFixture.cs b/src/NHibernate.Test/Linq/MiscellaneousTextFixture.cs index 9eada444639..983919e2914 100644 --- a/src/NHibernate.Test/Linq/MiscellaneousTextFixture.cs +++ b/src/NHibernate.Test/Linq/MiscellaneousTextFixture.cs @@ -27,7 +27,8 @@ from s in db.Shippers [Test(Description = "This sample uses Count to find the number of Orders placed before yesterday in the database.")] public void CountWithWhereClause() { - var q = from o in db.Orders where o.OrderDate <= DateTime.Today.AddDays(-1) select o; + var yesterday = DateTime.Today.AddDays(-1); + var q = from o in db.Orders where o.OrderDate <= yesterday select o; var count = q.Count(); diff --git a/src/NHibernate.Test/Linq/PreEvaluationTests.cs b/src/NHibernate.Test/Linq/PreEvaluationTests.cs new file mode 100644 index 00000000000..e5921352f16 --- /dev/null +++ b/src/NHibernate.Test/Linq/PreEvaluationTests.cs @@ -0,0 +1,267 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Cfg; +using NHibernate.SqlTypes; +using NUnit.Framework; +using Environment = NHibernate.Cfg.Environment; + +namespace NHibernate.Test.Linq +{ + [TestFixture(false, false)] + [TestFixture(true, false)] + [TestFixture(false, true)] + public class PreEvaluationTests : LinqTestCase + { + private readonly bool LegacyPreEvaluation; + private readonly bool FallbackOnPreEvaluation; + + public PreEvaluationTests(bool legacy, bool fallback) + { + LegacyPreEvaluation = legacy; + FallbackOnPreEvaluation = fallback; + } + + protected override void Configure(Configuration configuration) + { + base.Configure(configuration); + + configuration.SetProperty(Environment.FormatSql, "false"); + configuration.SetProperty(Environment.LinqToHqlLegacyPreEvaluation, LegacyPreEvaluation.ToString()); + configuration.SetProperty(Environment.LinqToHqlFallbackOnPreEvaluation, FallbackOnPreEvaluation.ToString()); + } + + [Test] + public void CanQueryByDateTimeNowUsingNotEqual() + { + var isSupported = IsFunctionSupported("current_timestamp"); + RunTest( + isSupported, + spy => + { + var x = db.Orders.Count(o => o.OrderDate.Value != DateTime.Now); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("current_timestamp", spy); + }); + } + + [Test] + public void CanQueryByDateTimeNow() + { + var isSupported = IsFunctionSupported("current_timestamp"); + RunTest( + isSupported, + spy => + { + var x = db.Orders.Count(o => o.OrderDate.Value < DateTime.Now); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("current_timestamp", spy); + }); + } + + [Test] + public void CanSelectDateTimeNow() + { + var isSupported = IsFunctionSupported("current_timestamp"); + RunTest( + isSupported, + spy => + { + var x = + db + .Orders.Select(o => new { id = o.OrderId, d = DateTime.Now }) + .OrderBy(o => o.id).Take(1).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + Assert.That(x[0].d.Kind, Is.EqualTo(DateTimeKind.Local)); + AssertFunctionInSql("current_timestamp", spy); + }); + } + + [Test] + public void CanQueryByDateTimeUtcNow() + { + var isSupported = IsFunctionSupported("current_utctimestamp"); + RunTest( + isSupported, + spy => + { + var x = db.Orders.Count(o => o.OrderDate.Value < DateTime.UtcNow); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("current_utctimestamp", spy); + }); + } + + [Test] + public void CanSelectDateTimeUtcNow() + { + var isSupported = IsFunctionSupported("current_utctimestamp"); + RunTest( + isSupported, + spy => + { + var x = + db + .Orders.Select(o => new { id = o.OrderId, d = DateTime.UtcNow }) + .OrderBy(o => o.id).Take(1).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + Assert.That(x[0].d.Kind, Is.EqualTo(DateTimeKind.Utc)); + AssertFunctionInSql("current_utctimestamp", spy); + }); + } + + [Test] + public void CanQueryByDateTimeToday() + { + var isSupported = IsFunctionSupported("current_date"); + RunTest( + isSupported, + spy => + { + var x = db.Orders.Count(o => o.OrderDate.Value < DateTime.Today); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("current_date", spy); + }); + } + + [Test] + public void CanSelectDateTimeToday() + { + var isSupported = IsFunctionSupported("current_date"); + RunTest( + isSupported, + spy => + { + var x = + db + .Orders.Select(o => new { id = o.OrderId, d = DateTime.Today }) + .OrderBy(o => o.id).Take(1).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + Assert.That(x[0].d.Kind, Is.EqualTo(DateTimeKind.Local)); + AssertFunctionInSql("current_date", spy); + }); + } + + [Test] + public void CanQueryByDateTimeOffsetTimeNow() + { + if (!TestDialect.SupportsSqlType(SqlTypeFactory.DateTimeOffSet)) + Assert.Ignore("Dialect does not support DateTimeOffSet"); + + var isSupported = IsFunctionSupported("current_timestamp_offset"); + RunTest( + isSupported, + spy => + { + var testDate = DateTimeOffset.Now.AddDays(-1); + var x = db.Orders.Count(o => testDate < DateTimeOffset.Now); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("current_timestamp_offset", spy); + }); + } + + [Test] + public void CanSelectDateTimeOffsetNow() + { + if (!TestDialect.SupportsSqlType(SqlTypeFactory.DateTimeOffSet)) + Assert.Ignore("Dialect does not support DateTimeOffSet"); + + var isSupported = IsFunctionSupported("current_timestamp_offset"); + RunTest( + isSupported, + spy => + { + var x = + db + .Orders.Select(o => new { id = o.OrderId, d = DateTimeOffset.Now }) + .OrderBy(o => o.id).Take(1).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + Assert.That(x[0].d.Offset, Is.EqualTo(DateTimeOffset.Now.Offset)); + AssertFunctionInSql("current_timestamp_offset", spy); + }); + } + + [Test] + public void CanQueryByDateTimeOffsetUtcNow() + { + if (!TestDialect.SupportsSqlType(SqlTypeFactory.DateTimeOffSet)) + Assert.Ignore("Dialect does not support DateTimeOffSet"); + + var isSupported = IsFunctionSupported("current_utctimestamp_offset"); + RunTest( + isSupported, + spy => + { + var testDate = DateTimeOffset.UtcNow.AddDays(-1); + var x = db.Orders.Count(o => testDate < DateTimeOffset.UtcNow); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("current_utctimestamp_offset", spy); + }); + } + + [Test] + public void CanSelectDateTimeOffsetUtcNow() + { + if (!TestDialect.SupportsSqlType(SqlTypeFactory.DateTimeOffSet)) + Assert.Ignore("Dialect does not support DateTimeOffSet"); + + var isSupported = IsFunctionSupported("current_utctimestamp_offset"); + RunTest( + isSupported, + spy => + { + var x = + db + .Orders.Select(o => new { id = o.OrderId, d = DateTimeOffset.UtcNow }) + .OrderBy(o => o.id).Take(1).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + Assert.That(x[0].d.Offset, Is.EqualTo(TimeSpan.Zero)); + AssertFunctionInSql("current_utctimestamp_offset", spy); + }); + } + + private void RunTest(bool isSupported, Action test) + { + using (var spy = new SqlLogSpy()) + { + try + { + test(spy); + } + catch (QueryException) + { + if (!isSupported && !FallbackOnPreEvaluation) + // Expected failure + return; + throw; + } + } + + if (!isSupported && !FallbackOnPreEvaluation) + Assert.Fail("The test should have thrown a QueryException, but has not thrown anything"); + } + + private void AssertFunctionInSql(string functionName, SqlLogSpy spy) + { + if (!IsFunctionSupported(functionName)) + Assert.Inconclusive($"{functionName} is not supported by the dialect"); + + var function = Dialect.Functions[functionName].Render(new List(), Sfi).ToString(); + + if (LegacyPreEvaluation) + Assert.That(spy.GetWholeLog(), Does.Not.Contain(function)); + else + Assert.That(spy.GetWholeLog(), Does.Contain(function)); + } + } +} diff --git a/src/NHibernate.Test/Linq/TryGetMappedTests.cs b/src/NHibernate.Test/Linq/TryGetMappedTests.cs index 11724e1ac9b..b65aa43f701 100644 --- a/src/NHibernate.Test/Linq/TryGetMappedTests.cs +++ b/src/NHibernate.Test/Linq/TryGetMappedTests.cs @@ -773,7 +773,7 @@ private void AssertResult( expectedComponentType = expectedComponentType ?? (o => o == null); var expression = query.Expression; - NhRelinqQueryParser.PreTransform(expression); + NhRelinqQueryParser.PreTransform(expression, Sfi); var constantToParameterMap = ExpressionParameterVisitor.Visit(expression, Sfi); var queryModel = NhRelinqQueryParser.Parse(expression); var requiredHqlParameters = new List(); diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 0d59b989cf4..600265ade3f 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -460,24 +460,40 @@ protected DateTime RoundForDialect(DateTime value) }} }; + protected bool IsFunctionSupported(string functionName) + { + // We could test Sfi.SQLFunctionRegistry.HasFunction(functionName) which has the advantage of + // accounting for additional functions added in configuration. But Dialect is normally never + // null, while Sfi could be not yet initialized, depending from where this function is called. + // Furthermore there are currently no additional functions added in configuration for NHibernate + // tests. + var dialect = Dialect; + if (!dialect.Functions.ContainsKey(functionName)) + return false; + + return !DialectsNotSupportingStandardFunction.TryGetValue(functionName, out var dialects) || + !dialects.Contains(dialect.GetType()); + } + protected void AssumeFunctionSupported(string functionName) { // We could test Sfi.SQLFunctionRegistry.HasFunction(functionName) which has the advantage of - // accounting for additionnal functions added in configuration. But Dialect is normally never + // accounting for additional functions added in configuration. But Dialect is normally never // null, while Sfi could be not yet initialized, depending from where this function is called. - // Furtermore there are currently no additionnal functions added in configuration for NHibernate + // Furthermore there are currently no additional functions added in configuration for NHibernate // tests. + var dialect = Dialect; Assume.That( - Dialect.Functions, + dialect.Functions, Does.ContainKey(functionName), - $"{Dialect} doesn't support {functionName} function."); + $"{dialect} doesn't support {functionName} function."); if (!DialectsNotSupportingStandardFunction.TryGetValue(functionName, out var dialects)) return; Assume.That( dialects, - Does.Not.Contain(Dialect.GetType()), - $"{Dialect} doesn't support {functionName} standard function."); + Does.Not.Contain(dialect.GetType()), + $"{dialect} doesn't support {functionName} standard function."); } protected void ClearQueryPlanCache() diff --git a/src/NHibernate/Cfg/Environment.cs b/src/NHibernate/Cfg/Environment.cs index 6613fe0fd2e..55a21c43637 100644 --- a/src/NHibernate/Cfg/Environment.cs +++ b/src/NHibernate/Cfg/Environment.cs @@ -226,6 +226,48 @@ public static string Version public const string LinqToHqlGeneratorsRegistry = "linqtohql.generatorsregistry"; + /// + /// Whether to use the legacy pre-evaluation or not in Linq queries. true by default. + /// + /// + /// + /// Legacy pre-evaluation is causing special properties or functions like DateTime.Now or + /// Guid.NewGuid() to be always evaluated with the .Net runtime and replaced in the query by + /// parameter values. + /// + /// + /// The new pre-evaluation allows them to be converted to HQL function calls which will be run on the db + /// side. This allows for example to retrieve the server time instead of the client time, or to generate + /// UUIDs for each row instead of an unique one for all rows. (This does not happen if the dialect does + /// not support the required HQL function.) + /// + /// + /// The new pre-evaluation will likely be enabled by default in the next major version (6.0). + /// + /// + public const string LinqToHqlLegacyPreEvaluation = "linqtohql.legacy_preevaluation"; + + /// + /// When the new pre-evaluation is enabled, should methods which translation is not supported by the current + /// dialect fallback to pre-evaluation? false by default. + /// + /// + /// + /// When this fallback option is enabled while legacy pre-evaluation is disabled, properties or functions + /// like DateTime.Now or Guid.NewGuid() used in Linq expressions will not fail when the dialect does not + /// support them, but will instead be pre-evaluated. + /// + /// + /// When this fallback option is disabled while legacy pre-evaluation is disabled, properties or functions + /// like DateTime.Now or Guid.NewGuid() used in Linq expressions will fail when the dialect does not + /// support them. + /// + /// + /// This option has no effect if the legacy pre-evaluation is enabled. + /// + /// + public const string LinqToHqlFallbackOnPreEvaluation = "linqtohql.fallback_on_preevaluation"; + /// Enable ordering of insert statements for the purpose of more efficient batching. public const string OrderInserts = "order_inserts"; diff --git a/src/NHibernate/Cfg/Settings.cs b/src/NHibernate/Cfg/Settings.cs index 4d4fc1fa96e..be520295c93 100644 --- a/src/NHibernate/Cfg/Settings.cs +++ b/src/NHibernate/Cfg/Settings.cs @@ -143,6 +143,44 @@ public Settings() /// public ILinqToHqlGeneratorsRegistry LinqToHqlGeneratorsRegistry { get; internal set; } + /// + /// Whether to use the legacy pre-evaluation or not in Linq queries. true by default. + /// + /// + /// + /// Legacy pre-evaluation is causing special properties or functions like DateTime.Now or + /// Guid.NewGuid() to be always evaluated with the .Net runtime and replaced in the query by + /// parameter values. + /// + /// + /// The new pre-evaluation allows them to be converted to HQL function calls which will be run on the db + /// side. This allows for example to retrieve the server time instead of the client time, or to generate + /// UUIDs for each row instead of an unique one for all rows. + /// + /// + public bool LinqToHqlLegacyPreEvaluation { get; internal set; } + + /// + /// When the new pre-evaluation is enabled, should methods which translation is not supported by the current + /// dialect fallback to pre-evaluation? false by default. + /// + /// + /// + /// When this fallback option is enabled while legacy pre-evaluation is disabled, properties or functions + /// like DateTime.Now or Guid.NewGuid() used in Linq expressions will not fail when the dialect does not + /// support them, but will instead be pre-evaluated. + /// + /// + /// When this fallback option is disabled while legacy pre-evaluation is disabled, properties or functions + /// like DateTime.Now or Guid.NewGuid() used in Linq expressions will fail when the dialect does not + /// support them. + /// + /// + /// This option has no effect if the legacy pre-evaluation is enabled. + /// + /// + public bool LinqToHqlFallbackOnPreEvaluation { get; internal set; } + public IQueryModelRewriterFactory QueryModelRewriterFactory { get; internal set; } #endregion diff --git a/src/NHibernate/Cfg/SettingsFactory.cs b/src/NHibernate/Cfg/SettingsFactory.cs index 24c5eebce67..46928b002cf 100644 --- a/src/NHibernate/Cfg/SettingsFactory.cs +++ b/src/NHibernate/Cfg/SettingsFactory.cs @@ -54,6 +54,15 @@ public Settings BuildSettings(IDictionary properties) settings.Dialect = dialect; settings.LinqToHqlGeneratorsRegistry = LinqToHqlGeneratorsRegistryFactory.CreateGeneratorsRegistry(properties); + // 6.0 TODO: default to false instead of true, and adjust documentation in xsd, xml comment on Environment + // and Setting properties, and doc\reference. + settings.LinqToHqlLegacyPreEvaluation = PropertiesHelper.GetBoolean( + Environment.LinqToHqlLegacyPreEvaluation, + properties, + true); + settings.LinqToHqlFallbackOnPreEvaluation = PropertiesHelper.GetBoolean( + Environment.LinqToHqlFallbackOnPreEvaluation, + properties); #region SQL Exception converter diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs index 3eef1635595..81c7aae473d 100644 --- a/src/NHibernate/Dialect/DB2Dialect.cs +++ b/src/NHibernate/Dialect/DB2Dialect.cs @@ -87,6 +87,8 @@ public DB2Dialect() RegisterFunction("tan", new StandardSQLFunction("tan", NHibernateUtil.Double)); RegisterFunction("variance", new StandardSQLFunction("variance", NHibernateUtil.Double)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("current_timestamp", NHibernateUtil.LocalDateTime, false)); + RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate, false)); RegisterFunction("julian_day", new StandardSQLFunction("julian_day", NHibernateUtil.Int32)); RegisterFunction("microsecond", new StandardSQLFunction("microsecond", NHibernateUtil.Int32)); RegisterFunction("midnight_seconds", new StandardSQLFunction("midnight_seconds", NHibernateUtil.Int32)); @@ -138,8 +140,6 @@ public DB2Dialect() RegisterFunction("bxor", new Function.BitwiseFunctionOperation("bitxor")); RegisterFunction("bnot", new Function.BitwiseFunctionOperation("bitnot")); - RegisterFunction("current_timestamp", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, false)); - DefaultProperties[Environment.ConnectionDriver] = "NHibernate.Driver.DB2Driver"; } diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index 423c5082361..af1995dca06 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -105,7 +105,7 @@ protected Dialect() // the syntax of current_timestamp is extracted from H3.2 tests // - test\hql\ASTParserLoadingTest.java // - test\org\hibernate\test\hql\HQLTest.java - RegisterFunction("current_timestamp", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, true)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("current_timestamp", NHibernateUtil.LocalDateTime, true)); RegisterFunction("sysdate", new NoArgSQLFunction("sysdate", NHibernateUtil.DateTime, false)); //map second/minute/hour/day/month/year to ANSI extract(), override on subclasses diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs index af32c3e09d5..e4b8b5dcb8a 100644 --- a/src/NHibernate/Dialect/FirebirdDialect.cs +++ b/src/NHibernate/Dialect/FirebirdDialect.cs @@ -154,7 +154,7 @@ public override SqlString Render(IList args, ISessionFactoryImplementor factory) [Serializable] private class CurrentTimeStamp : NoArgSQLFunction { - public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.DateTime, true) + public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.LocalDateTime, true) { } @@ -413,6 +413,7 @@ protected virtual void RegisterFunctions() private void OverrideStandardHQLFunctions() { RegisterFunction("current_timestamp", new CurrentTimeStamp()); + RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate, false)); RegisterFunction("length", new StandardSafeSQLFunction("char_length", NHibernateUtil.Int64, 1)); RegisterFunction("nullif", new StandardSafeSQLFunction("nullif", 2)); RegisterFunction("lower", new StandardSafeSQLFunction("lower", NHibernateUtil.String, 1)); diff --git a/src/NHibernate/Dialect/HanaDialectBase.cs b/src/NHibernate/Dialect/HanaDialectBase.cs index b8de4f2a5d7..150ba9aab2d 100644 --- a/src/NHibernate/Dialect/HanaDialectBase.cs +++ b/src/NHibernate/Dialect/HanaDialectBase.cs @@ -439,20 +439,20 @@ protected virtual void RegisterHANAFunctions() RegisterFunction("cosh", new StandardSQLFunction("cosh", NHibernateUtil.Double)); RegisterFunction("cot", new StandardSQLFunction("cot", NHibernateUtil.Double)); RegisterFunction("current_connection", new NoArgSQLFunction("current_connection", NHibernateUtil.Int32)); - RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.DateTime, false)); + RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate, false)); RegisterFunction("current_identity_value", new NoArgSQLFunction("current_identity_value", NHibernateUtil.Int64)); RegisterFunction("current_mvcc_snapshot_timestamp", new NoArgSQLFunction("current_mvcc_snapshot_timestamp", NHibernateUtil.Int32)); RegisterFunction("current_object_schema", new NoArgSQLFunction("current_object_schema", NHibernateUtil.String)); RegisterFunction("current_schema", new NoArgSQLFunction("current_schema", NHibernateUtil.String, false)); RegisterFunction("current_time", new NoArgSQLFunction("current_time", NHibernateUtil.DateTime, false)); - RegisterFunction("current_timestamp", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, false)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("current_timestamp", NHibernateUtil.LocalDateTime, false)); RegisterFunction("current_transaction_isolation_level", new NoArgSQLFunction("current_transaction_isolation_level", NHibernateUtil.String, false)); RegisterFunction("current_update_statement_sequence", new NoArgSQLFunction("current_update_statement_sequence", NHibernateUtil.Int64)); RegisterFunction("current_update_transaction", new NoArgSQLFunction("current_update_transaction", NHibernateUtil.Int64)); RegisterFunction("current_user", new NoArgSQLFunction("current_user", NHibernateUtil.String, false)); RegisterFunction("current_utcdate", new NoArgSQLFunction("current_utcdate", NHibernateUtil.DateTime, false)); RegisterFunction("current_utctime", new NoArgSQLFunction("current_utctime", NHibernateUtil.DateTime, false)); - RegisterFunction("current_utctimestamp", new NoArgSQLFunction("current_utctimestamp", NHibernateUtil.DateTime, false)); + RegisterFunction("current_utctimestamp", new NoArgSQLFunction("current_utctimestamp", NHibernateUtil.UtcDateTime, false)); RegisterFunction("dayname", new StandardSQLFunction("dayname", NHibernateUtil.String)); RegisterFunction("dayofmonth", new StandardSQLFunction("dayofmonth", NHibernateUtil.Int32)); RegisterFunction("dayofyear", new StandardSQLFunction("dayofyear", NHibernateUtil.Int32)); diff --git a/src/NHibernate/Dialect/InformixDialect.cs b/src/NHibernate/Dialect/InformixDialect.cs index 7739018dc1e..2f942815a57 100644 --- a/src/NHibernate/Dialect/InformixDialect.cs +++ b/src/NHibernate/Dialect/InformixDialect.cs @@ -78,7 +78,8 @@ public InformixDialect() // RegisterFunction("cast", new CastFunction()); // RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(", "||", ")")); - RegisterFunction("current_timestamp", new NoArgSQLFunction("current", NHibernateUtil.DateTime, false)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("current", NHibernateUtil.LocalDateTime, false)); + RegisterFunction("current_date", new NoArgSQLFunction("today", NHibernateUtil.LocalDate, false)); RegisterFunction("sysdate", new NoArgSQLFunction("today", NHibernateUtil.DateTime, false)); RegisterFunction("current", new NoArgSQLFunction("current", NHibernateUtil.DateTime, false)); RegisterFunction("today", new NoArgSQLFunction("today", NHibernateUtil.DateTime, false)); diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index 5cada797985..e4dd5a9e0af 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -326,7 +326,9 @@ protected virtual void RegisterFunctions() RegisterFunction("right", new SQLFunctionTemplate(NHibernateUtil.String, "right(?1, ?2)")); RegisterFunction("locate", new StandardSQLFunction("charindex", NHibernateUtil.Int32)); - RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.DateTime, true)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.LocalDateTime, true)); + RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.LocalDate, "dateadd(dd, 0, datediff(dd, 0, getdate()))")); + RegisterFunction("current_utctimestamp", new NoArgSQLFunction("getutcdate", NHibernateUtil.UtcDateTime, true)); RegisterFunction("second", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(second, ?1)")); RegisterFunction("minute", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(minute, ?1)")); RegisterFunction("hour", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(hour, ?1)")); diff --git a/src/NHibernate/Dialect/MsSql2008Dialect.cs b/src/NHibernate/Dialect/MsSql2008Dialect.cs index 7c40549a700..d0ef5580389 100644 --- a/src/NHibernate/Dialect/MsSql2008Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2008Dialect.cs @@ -51,11 +51,20 @@ protected override void RegisterFunctions() { RegisterFunction( "current_timestamp", - new NoArgSQLFunction("sysdatetime", NHibernateUtil.DateTime, true)); + new NoArgSQLFunction("sysdatetime", NHibernateUtil.LocalDateTime, true)); + RegisterFunction( + "current_utctimestamp", + new NoArgSQLFunction("sysutcdatetime", NHibernateUtil.UtcDateTime, true)); } + + RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.LocalDate, "cast(getdate() as date)")); RegisterFunction( "current_timestamp_offset", new NoArgSQLFunction("sysdatetimeoffset", NHibernateUtil.DateTimeOffset, true)); + RegisterFunction( + "current_utctimestamp_offset", + new SQLFunctionTemplate(NHibernateUtil.DateTimeOffset, "todatetimeoffset(sysutcdatetime(), 0)")); + RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.Date, "cast(?1 as date)")); } protected override void RegisterKeywords() diff --git a/src/NHibernate/Dialect/MsSqlCeDialect.cs b/src/NHibernate/Dialect/MsSqlCeDialect.cs index 90da41cf9bb..46c15f40cf5 100644 --- a/src/NHibernate/Dialect/MsSqlCeDialect.cs +++ b/src/NHibernate/Dialect/MsSqlCeDialect.cs @@ -172,7 +172,8 @@ protected virtual void RegisterFunctions() RegisterFunction("str", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as nvarchar)")); RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as nvarchar)")); - RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.DateTime, true)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.LocalDateTime, true)); + RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.LocalDate, "dateadd(dd, 0, datediff(dd, 0, getdate()))")); RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.DateTime, "dateadd(dd, 0, datediff(dd, 0, ?1))")); RegisterFunction("second", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(second, ?1)")); RegisterFunction("minute", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(minute, ?1)")); diff --git a/src/NHibernate/Dialect/MySQL55Dialect.cs b/src/NHibernate/Dialect/MySQL55Dialect.cs index c7a8004cb1d..26dd7de2709 100644 --- a/src/NHibernate/Dialect/MySQL55Dialect.cs +++ b/src/NHibernate/Dialect/MySQL55Dialect.cs @@ -10,5 +10,12 @@ public MySQL55Dialect() RegisterColumnType(DbType.Guid, "CHAR(36)"); RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "?1")); } + + protected override void RegisterFunctions() + { + base.RegisterFunctions(); + + RegisterFunction("current_utctimestamp", new NoArgSQLFunction("UTC_TIMESTAMP", NHibernateUtil.UtcDateTime, true)); + } } } diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index 9a83de26eb3..a1816e95209 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -294,7 +294,7 @@ protected virtual void RegisterFunctions() RegisterFunction("hex", new StandardSQLFunction("hex", NHibernateUtil.String)); RegisterFunction("soundex", new StandardSQLFunction("soundex", NHibernateUtil.String)); - RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.Date, false)); + RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate, false)); RegisterFunction("current_time", new NoArgSQLFunction("current_time", NHibernateUtil.Time, false)); RegisterFunction("second", new StandardSQLFunction("second", NHibernateUtil.Int32)); diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index b2103b87965..5d6bad35705 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -252,7 +252,7 @@ protected virtual void RegisterFunctions() // In Oracle, date includes a time, just with fractional seconds dropped. For actually only having // the date, it must be truncated. Otherwise comparisons may yield unexpected results. - RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.Date, "trunc(current_date)")); + RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.LocalDate, "trunc(current_date)")); RegisterFunction("current_time", new NoArgSQLFunction("current_timestamp", NHibernateUtil.Time, false)); RegisterFunction("current_timestamp", new CurrentTimeStamp()); @@ -571,7 +571,7 @@ public override bool SupportsExistsInSelect [Serializable] private class CurrentTimeStamp : NoArgSQLFunction { - public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.DateTime, true) {} + public CurrentTimeStamp() : base("current_timestamp", NHibernateUtil.LocalDateTime, true) {} public override SqlString Render(IList args, ISessionFactoryImplementor factory) { diff --git a/src/NHibernate/Dialect/Oracle9iDialect.cs b/src/NHibernate/Dialect/Oracle9iDialect.cs index 868b8170ccd..b36b1a34b37 100644 --- a/src/NHibernate/Dialect/Oracle9iDialect.cs +++ b/src/NHibernate/Dialect/Oracle9iDialect.cs @@ -1,4 +1,5 @@ using System.Data; +using NHibernate.Dialect.Function; using NHibernate.SqlCommand; using NHibernate.SqlTypes; @@ -41,6 +42,15 @@ protected override void RegisterDateTimeTypeMappings() RegisterColumnType(DbType.Xml, "XMLTYPE"); } + protected override void RegisterFunctions() + { + base.RegisterFunctions(); + + RegisterFunction( + "current_utctimestamp", + new SQLFunctionTemplate(NHibernateUtil.UtcDateTime, "SYS_EXTRACT_UTC(current_timestamp)")); + } + public override long TimestampResolutionInTicks => 1; public override string GetSelectClauseNullString(SqlType sqlType) diff --git a/src/NHibernate/Dialect/PostgreSQL81Dialect.cs b/src/NHibernate/Dialect/PostgreSQL81Dialect.cs index 77953561029..b59ca08388c 100644 --- a/src/NHibernate/Dialect/PostgreSQL81Dialect.cs +++ b/src/NHibernate/Dialect/PostgreSQL81Dialect.cs @@ -1,4 +1,5 @@ using System.Data; +using NHibernate.Dialect.Function; using NHibernate.SqlCommand; namespace NHibernate.Dialect @@ -40,6 +41,11 @@ protected override void RegisterDateTimeTypeMappings() RegisterColumnType(DbType.Time, 6, "time($s)"); // Not overriding default scale: Posgres doc writes it means "no explicit limit", so max of what it can support, // which suits our needs. + + // timezone seems not available prior to version 8.0 + RegisterFunction( + "current_utctimestamp", + new SQLFunctionTemplate(NHibernateUtil.UtcDateTime, "timezone('UTC', current_timestamp)")); } public override string ForUpdateNowaitString diff --git a/src/NHibernate/Dialect/PostgreSQLDialect.cs b/src/NHibernate/Dialect/PostgreSQLDialect.cs index b81c2df86a8..da522e533f3 100644 --- a/src/NHibernate/Dialect/PostgreSQLDialect.cs +++ b/src/NHibernate/Dialect/PostgreSQLDialect.cs @@ -60,7 +60,7 @@ public PostgreSQLDialect() RegisterColumnType(DbType.String, 1073741823, "text"); // Override standard HQL function - RegisterFunction("current_timestamp", new NoArgSQLFunction("now", NHibernateUtil.DateTime, true)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("now", NHibernateUtil.LocalDateTime, true)); RegisterFunction("str", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as varchar)")); RegisterFunction("locate", new PositionSubstringFunction()); RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); @@ -94,6 +94,7 @@ public PostgreSQLDialect() // Register the date function, since when used in LINQ select clauses, NH must know the data type. RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.Date, "cast(?1 as date)")); + RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate, false)); RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "?1::TEXT")); diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index f1f86743313..22506edf333 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -74,12 +74,17 @@ protected virtual void RegisterFunctions() RegisterFunction("month", new SQLFunctionTemplate(NHibernateUtil.Int32, "cast(strftime('%m', ?1) as int)")); RegisterFunction("year", new SQLFunctionTemplate(NHibernateUtil.Int32, "cast(strftime('%Y', ?1) as int)")); // Uses local time like MSSQL and PostgreSQL. - RegisterFunction("current_timestamp", new SQLFunctionTemplate(NHibernateUtil.DateTime, "datetime(current_timestamp, 'localtime')")); + RegisterFunction("current_timestamp", new SQLFunctionTemplate(NHibernateUtil.LocalDateTime, "datetime(current_timestamp, 'localtime')")); + RegisterFunction("current_utctimestamp", new SQLFunctionTemplate(NHibernateUtil.UtcDateTime, "datetime(current_timestamp)")); // The System.Data.SQLite driver stores both Date and DateTime as 'YYYY-MM-DD HH:MM:SS' // The SQLite date() function returns YYYY-MM-DD, which unfortunately SQLite does not consider // as equal to 'YYYY-MM-DD 00:00:00'. Because of this, it is best to return the // 'YYYY-MM-DD 00:00:00' format for the date function. RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.Date, "datetime(date(?1))")); + // SQLite has current_date, but as current_timestamp, it is in UTC. So converting the timestamp to + // localtime then to date then, like the above date function, go back to datetime format for comparisons + // sake. + RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.LocalDate, "datetime(date(current_timestamp, 'localtime'))")); RegisterFunction("substring", new StandardSQLFunction("substr", NHibernateUtil.String)); RegisterFunction("left", new SQLFunctionTemplate(NHibernateUtil.String, "substr(?1,1,?2)")); diff --git a/src/NHibernate/Dialect/SybaseASA9Dialect.cs b/src/NHibernate/Dialect/SybaseASA9Dialect.cs index f839871b6a1..dfae7baa471 100644 --- a/src/NHibernate/Dialect/SybaseASA9Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASA9Dialect.cs @@ -72,7 +72,7 @@ public SybaseASA9Dialect() //RegisterColumnType(DbType.Xml, "TEXT"); // Override standard HQL function - RegisterFunction("current_timestamp", new StandardSQLFunction("current_timestamp")); + RegisterFunction("current_timestamp", new StandardSQLFunction("current_timestamp", NHibernateUtil.LocalDateTime)); RegisterFunction("length", new StandardSafeSQLFunction("length", NHibernateUtil.String, 1)); RegisterFunction("nullif", new StandardSafeSQLFunction("nullif", 2)); RegisterFunction("lower", new StandardSafeSQLFunction("lower", NHibernateUtil.String, 1)); diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index 0a84a822424..de9514431ca 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -68,9 +68,10 @@ public SybaseASE15Dialect() RegisterFunction("concat", new VarArgsSQLFunction(NHibernateUtil.String, "(","+",")")); RegisterFunction("cos", new StandardSQLFunction("cos", NHibernateUtil.Double)); RegisterFunction("cot", new StandardSQLFunction("cot", NHibernateUtil.Double)); - RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.Date)); + RegisterFunction("current_date", new NoArgSQLFunction("current_date", NHibernateUtil.LocalDate)); RegisterFunction("current_time", new NoArgSQLFunction("current_time", NHibernateUtil.Time)); - RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.DateTime)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.LocalDateTime)); + RegisterFunction("current_utctimestamp", new NoArgSQLFunction("getutcdate", NHibernateUtil.UtcDateTime)); RegisterFunction("datename", new StandardSQLFunction("datename", NHibernateUtil.String)); RegisterFunction("day", new StandardSQLFunction("day", NHibernateUtil.Int32)); RegisterFunction("degrees", new StandardSQLFunction("degrees", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs index 388947b45c4..10256111696 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs @@ -238,9 +238,9 @@ protected virtual void RegisterDateFunctions() RegisterFunction("ymd", new StandardSQLFunction("ymd", NHibernateUtil.Date)); // compatibility functions - RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.DateTime, true)); + RegisterFunction("current_timestamp", new NoArgSQLFunction("getdate", NHibernateUtil.LocalDateTime, true)); RegisterFunction("current_time", new NoArgSQLFunction("getdate", NHibernateUtil.Time, true)); - RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.Date, "date(getdate())")); + RegisterFunction("current_date", new SQLFunctionTemplate(NHibernateUtil.LocalDate, "date(getdate())")); } protected virtual void RegisterStringFunctions() diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere12Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere12Dialect.cs index a0b0125e8b7..88a846ba1de 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere12Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere12Dialect.cs @@ -77,9 +77,16 @@ protected override void RegisterDateTimeTypeMappings() protected override void RegisterDateFunctions() { base.RegisterDateFunctions(); + + RegisterFunction( + "current_utctimestamp", + new SQLFunctionTemplate(NHibernateUtil.UtcDateTime, "cast(current UTC timestamp as timestamp)")); RegisterFunction( "current_timestamp_offset", new NoArgSQLFunction("sysdatetimeoffset", NHibernateUtil.DateTimeOffset, true)); + RegisterFunction( + "current_utctimestamp_offset", + new SQLFunctionTemplate(NHibernateUtil.DateTimeOffset, "(current UTC timestamp)")); } /// diff --git a/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs b/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs new file mode 100644 index 00000000000..ef2c154b81d --- /dev/null +++ b/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Engine; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using Environment = NHibernate.Cfg.Environment; + +namespace NHibernate.Linq.Functions +{ + public class DateTimeNowHqlGenerator : BaseHqlGeneratorForProperty, IAllowPreEvaluationHqlGenerator + { + private static readonly MemberInfo DateTimeNow = ReflectHelper.GetProperty(() => DateTime.Now); + private static readonly MemberInfo DateTimeUtcNow = ReflectHelper.GetProperty(() => DateTime.UtcNow); + private static readonly MemberInfo DateTimeToday = ReflectHelper.GetProperty(() => DateTime.Today); + private static readonly MemberInfo DateTimeOffsetNow = ReflectHelper.GetProperty(() => DateTimeOffset.Now); + private static readonly MemberInfo DateTimeOffsetUtcNow = ReflectHelper.GetProperty(() => DateTimeOffset.UtcNow); + + private readonly Dictionary _hqlFunctions = new Dictionary() + { + { DateTimeNow, "current_timestamp" }, + { DateTimeUtcNow, "current_utctimestamp" }, + // There is also sysdate, but it is troublesome: under some databases, "sys" prefixed functions return the + // system time (time according to the server time zone) while "current" prefixed functions return the + // session time (time according to the connection time zone), thus introducing a discrepancy with + // current_timestamp. + // Moreover sysdate is registered by default as a datetime, not as a date. (It could make sense for + // Oracle, which returns a time part for dates, just dropping fractional seconds. But Oracle dialect + // overrides it as a NHibernate date, without truncating it for SQL comparisons...) + { DateTimeToday, "current_date" }, + { DateTimeOffsetNow, "current_timestamp_offset" }, + { DateTimeOffsetUtcNow, "current_utctimestamp_offset" }, + }; + + public DateTimeNowHqlGenerator() + { + SupportedProperties = new[] + { + DateTimeNow, + DateTimeUtcNow, + DateTimeToday, + DateTimeOffsetNow, + DateTimeOffsetUtcNow, + }; + } + + public override HqlTreeNode BuildHql( + MemberInfo member, + Expression expression, + HqlTreeBuilder treeBuilder, + IHqlExpressionVisitor visitor) + { + return treeBuilder.MethodCall(_hqlFunctions[member]); + } + + public bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor factory) + { + var functionName = _hqlFunctions[member]; + if (factory.Dialect.Functions.ContainsKey(functionName)) + return false; + + if (factory.Settings.LinqToHqlFallbackOnPreEvaluation) + return true; + + throw new QueryException( + $"Cannot translate {member.DeclaringType.Name}.{member.Name}: {functionName} is " + + $"not supported by {factory.Dialect}. Either enable the fallback on pre-evaluation " + + $"({Environment.LinqToHqlFallbackOnPreEvaluation}) or evaluate {member.Name} " + + "outside of the query."); + } + } +} diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index ea5ab8a159c..27c28e1bf72 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -56,6 +56,7 @@ public DefaultLinqToHqlGeneratorsRegistry() this.Merge(new CollectionContainsGenerator()); this.Merge(new DateTimePropertiesHqlGenerator()); + this.Merge(new DateTimeNowHqlGenerator()); this.Merge(new DecimalAddGenerator()); this.Merge(new DecimalDivideGenerator()); diff --git a/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs b/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs new file mode 100644 index 00000000000..ab2266b0f9a --- /dev/null +++ b/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs @@ -0,0 +1,24 @@ +using System; +using System.Reflection; +using NHibernate.Engine; + +namespace NHibernate.Linq.Functions +{ + public interface IAllowPreEvaluationHqlGenerator + { + /// + /// Should pre-evaluation be allowed for this property? + /// + /// The property. + /// The session factory. + /// + /// if the property should be evaluated before running the query whenever possible, + /// if it must always be translated to the equivalent HQL call. + /// + /// Implementors should return by default. Returning + /// is mainly useful when the HQL translation is a non-deterministic function call like NEWGUID() or + /// a function which value on server side can differ from the equivalent client value, like + /// . + bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor factory); + } +} diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForProperty.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForProperty.cs index 83650ae2185..474ea552ced 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForProperty.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForProperty.cs @@ -1,14 +1,46 @@ using System.Collections.Generic; using System.Linq.Expressions; using System.Reflection; +using NHibernate.Engine; using NHibernate.Hql.Ast; using NHibernate.Linq.Visitors; namespace NHibernate.Linq.Functions { - public interface IHqlGeneratorForProperty - { - IEnumerable SupportedProperties { get; } - HqlTreeNode BuildHql(MemberInfo member, Expression expression, HqlTreeBuilder treeBuilder, IHqlExpressionVisitor visitor); - } -} \ No newline at end of file + public interface IHqlGeneratorForProperty + { + IEnumerable SupportedProperties { get; } + + HqlTreeNode BuildHql( + MemberInfo member, + Expression expression, + HqlTreeBuilder treeBuilder, + IHqlExpressionVisitor visitor); + } + + // 6.0 TODO: merge into IHqlGeneratorForProperty + public static class HqlGeneratorForPropertyExtensions + { + /// + /// Should pre-evaluation be allowed for this property? + /// + /// The property's HQL generator. + /// The property. + /// The session factory. + /// + /// if the property should be evaluated before running the query whenever possible, + /// if it must always be translated to the equivalent HQL call. + /// + public static bool AllowPreEvaluation( + this IHqlGeneratorForProperty generator, + MemberInfo member, + ISessionFactoryImplementor factory) + { + if (generator is IAllowPreEvaluationHqlGenerator allowPreEvalGenerator) + return allowPreEvalGenerator.AllowPreEvaluation(member, factory); + + // By default, everything should be pre-evaluated whenever possible. + return true; + } + } +} diff --git a/src/NHibernate/Linq/NhLinqExpression.cs b/src/NHibernate/Linq/NhLinqExpression.cs index 2a00840949a..918b56a37f8 100644 --- a/src/NHibernate/Linq/NhLinqExpression.cs +++ b/src/NHibernate/Linq/NhLinqExpression.cs @@ -39,7 +39,7 @@ public class NhLinqExpression : IQueryExpression, ICacheableQueryExpression public NhLinqExpression(Expression expression, ISessionFactoryImplementor sessionFactory) { - _expression = NhRelinqQueryParser.PreTransform(expression); + _expression = NhRelinqQueryParser.PreTransform(expression, sessionFactory); // We want logging to be as close as possible to the original expression sent from the // application. But if we log before partial evaluation done in PreTransform, the log won't diff --git a/src/NHibernate/Linq/NhRelinqQueryParser.cs b/src/NHibernate/Linq/NhRelinqQueryParser.cs index bafd050280b..d32908ed0af 100644 --- a/src/NHibernate/Linq/NhRelinqQueryParser.cs +++ b/src/NHibernate/Linq/NhRelinqQueryParser.cs @@ -1,8 +1,10 @@ +using System; using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; +using NHibernate.Engine; using NHibernate.Linq.ExpressionTransformers; using NHibernate.Linq.Visitors; using NHibernate.Util; @@ -44,15 +46,30 @@ static NhRelinqQueryParser() QueryParser = new QueryParser(expressionTreeParser); } + // Obsolete since v5.3 /// - /// Applies the minimal transformations required before parameterization, + /// Applies the minimal transformations required before parametrization, /// expression key computing and parsing. /// /// The expression to transform. /// The transformed expression. + [Obsolete("Use overload with an additional sessionFactory parameter")] public static Expression PreTransform(Expression expression) { - var partiallyEvaluatedExpression = NhPartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression); + return PreTransform(expression, null); + } + + /// + /// Applies the minimal transformations required before parametrization, + /// expression key computing and parsing. + /// + /// The expression to transform. + /// The session factory. + /// The transformed expression. + public static Expression PreTransform(Expression expression, ISessionFactoryImplementor sessionFactory) + { + var partiallyEvaluatedExpression = + NhPartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression, sessionFactory); return PreProcessor.Process(partiallyEvaluatedExpression); } diff --git a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs index 934fba8ec94..f333fc5e61d 100644 --- a/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/MemberExpressionJoinDetector.cs @@ -32,6 +32,13 @@ public MemberExpressionJoinDetector(IIsEntityDecider isEntityDecider, IJoiner jo protected override Expression VisitMember(MemberExpression expression) { + // A static member expression such as DateTime.Now has a null Expression. + if (expression.Expression == null) + { + // A static member call is never a join, and it is not an instance member access either. + return base.VisitMember(expression); + } + var isIdentifier = _isEntityDecider.IsIdentifier(expression.Expression.Type, expression.Member.Name); if (isIdentifier) _hasIdentifier = true; diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index 105a098817d..dccc28df9ab 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -2,6 +2,8 @@ using System.Linq; using System.Linq.Expressions; using NHibernate.Collection; +using NHibernate.Engine; +using NHibernate.Linq.Functions; using NHibernate.Util; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; @@ -12,20 +14,31 @@ namespace NHibernate.Linq.Visitors { internal class NhPartialEvaluatingExpressionVisitor : RelinqExpressionVisitor, IPartialEvaluationExceptionExpressionVisitor { + private readonly ISessionFactoryImplementor _sessionFactory; + + internal NhPartialEvaluatingExpressionVisitor(ISessionFactoryImplementor sessionFactory) + { + _sessionFactory = sessionFactory; + } + protected override Expression VisitConstant(ConstantExpression expression) { if (expression.Value is Expression value) { - return EvaluateIndependentSubtrees(value); + return EvaluateIndependentSubtrees(value, _sessionFactory); } return base.VisitConstant(expression); } - public static Expression EvaluateIndependentSubtrees(Expression expression) + public static Expression EvaluateIndependentSubtrees( + Expression expression, + ISessionFactoryImplementor sessionFactory) { - var evaluatedExpression = PartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees(expression, new NhEvaluatableExpressionFilter()); - return new NhPartialEvaluatingExpressionVisitor().Visit(evaluatedExpression); + var evaluatedExpression = PartialEvaluatingExpressionVisitor.EvaluateIndependentSubtrees( + expression, + new NhEvaluatableExpressionFilter(sessionFactory)); + return new NhPartialEvaluatingExpressionVisitor(sessionFactory).Visit(evaluatedExpression); } public Expression VisitPartialEvaluationException(PartialEvaluationExceptionExpression partialEvaluationExceptionExpression) @@ -38,6 +51,13 @@ public Expression VisitPartialEvaluationException(PartialEvaluationExceptionExpr internal class NhEvaluatableExpressionFilter : EvaluatableExpressionFilterBase { + private readonly ISessionFactoryImplementor _sessionFactory; + + internal NhEvaluatableExpressionFilter(ISessionFactoryImplementor sessionFactory) + { + _sessionFactory = sessionFactory; + } + public override bool IsEvaluatableConstant(ConstantExpression node) { if (node.Value is IPersistentCollection && node.Value is IQueryable) @@ -48,6 +68,18 @@ public override bool IsEvaluatableConstant(ConstantExpression node) return base.IsEvaluatableConstant(node); } + public override bool IsEvaluatableMember(MemberExpression node) + { + if (node == null) + throw new ArgumentNullException(nameof(node)); + + if (_sessionFactory == null || _sessionFactory.Settings.LinqToHqlLegacyPreEvaluation || + !_sessionFactory.Settings.LinqToHqlGeneratorsRegistry.TryGetGenerator(node.Member, out var generator)) + return true; + + return generator.AllowPreEvaluation(node.Member, _sessionFactory); + } + public override bool IsEvaluatableMethodCall(MethodCallExpression node) { if (node == null) diff --git a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs index d9e2d6c06f5..0fa6e4da40d 100644 --- a/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs +++ b/src/NHibernate/Linq/Visitors/NullableExpressionDetector.cs @@ -128,6 +128,13 @@ private bool IsNullable(MemberExpression memberExpression, BinaryExpression equa { if (_functionRegistry.TryGetGenerator(memberExpression.Member, out _)) { + // The expression can be null when the member is static (e.g. DateTime.Now). + // In such cases we suppose that the value cannot be null. + if (memberExpression.Expression == null) + { + return false; + } + // We have to skip the property as it will be converted to a function that can return null // if the argument is null return IsNullable(memberExpression.Expression, equalityExpression); diff --git a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs index 68f88c6cc55..4be0b8d2af2 100644 --- a/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs +++ b/src/NHibernate/Linq/Visitors/WhereJoinDetector.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using NHibernate.Linq.Clauses; using NHibernate.Linq.ReWriters; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; @@ -289,7 +288,7 @@ protected override Expression VisitSubQuery(SubQueryExpression expression) return expression; } - // We would usually get NULL if one of our inner member expresions was null. + // We would usually get NULL if one of our inner member expressions was null. // However, it's possible a method call will convert the null value from the failed join into a non-null value. // This could be optimized by actually checking what the method does. For example StartsWith("s") would leave null as null and would still allow us to inner join. //protected override Expression VisitMethodCall(MethodCallExpression expression) @@ -307,7 +306,17 @@ protected override Expression VisitMember(MemberExpression expression) // I'm not sure what processing re-linq does to strange member expressions. // TODO: I suspect this code doesn't add the right joins for the last case. - var isIdentifier = _isEntityDecider.IsIdentifier(expression.Expression.Type, expression.Member.Name); + // A static member expression such as DateTime.Now has a null Expression. + if (expression.Expression == null) + { + // A static member call is never a join, and it is not an instance member access either: leave + // the current value on stack, untouched. + return base.VisitMember(expression); + } + + var isIdentifier = _isEntityDecider.IsIdentifier( + expression.Expression.Type, + expression.Member.Name); if (!isIdentifier) _memberExpressionDepth++; @@ -332,7 +341,7 @@ protected override Expression VisitMember(MemberExpression expression) values.MemberExpressionValuesIfEmptyOuterJoined[key] = PossibleValueSet.CreateNull(expression.Type); } SetResultValues(values); - + return result; } diff --git a/src/NHibernate/NHibernateUtil.cs b/src/NHibernate/NHibernateUtil.cs index e781db52bef..226d26bf67c 100644 --- a/src/NHibernate/NHibernateUtil.cs +++ b/src/NHibernate/NHibernateUtil.cs @@ -137,6 +137,11 @@ public static IType GuessType(System.Type type) /// public static readonly DateType Date = new DateType(); + /// + /// NHibernate local date type + /// + public static readonly DateType LocalDate = new LocalDateType(); + /// /// NHibernate decimal type /// diff --git a/src/NHibernate/Type/DateType.cs b/src/NHibernate/Type/DateType.cs index 08e4097a7a7..76d3fbb99e9 100644 --- a/src/NHibernate/Type/DateType.cs +++ b/src/NHibernate/Type/DateType.cs @@ -35,7 +35,7 @@ public DateType() : base(SqlTypeFactory.Date) /// protected override DateTime AdjustDateTime(DateTime dateValue) => - dateValue.Date; + Kind == DateTimeKind.Unspecified ? dateValue.Date : DateTime.SpecifyKind(dateValue.Date, Kind); /// public override bool IsEqual(object x, object y) diff --git a/src/NHibernate/Type/LocalDateType.cs b/src/NHibernate/Type/LocalDateType.cs new file mode 100644 index 00000000000..d6fadaec31b --- /dev/null +++ b/src/NHibernate/Type/LocalDateType.cs @@ -0,0 +1,17 @@ +using System; +using System.Data; + +namespace NHibernate.Type +{ + /// + /// Maps the Year, Month, and Day of a Property to a + /// column. Specify when reading + /// dates from . + /// + [Serializable] + public class LocalDateType : DateType + { + /// + protected override DateTimeKind Kind => DateTimeKind.Local; + } +} diff --git a/src/NHibernate/Util/ReflectHelper.cs b/src/NHibernate/Util/ReflectHelper.cs index 7de13090ff7..fc9856203ba 100644 --- a/src/NHibernate/Util/ReflectHelper.cs +++ b/src/NHibernate/Util/ReflectHelper.cs @@ -149,6 +149,21 @@ public static MemberInfo GetProperty(Expression + /// Gets the static field or property to be accessed. + /// + /// The type of the property. + /// The expression representing the property getter. + /// The of the property. + public static MemberInfo GetProperty(Expression> property) + { + if (property == null) + { + throw new ArgumentNullException(nameof(property)); + } + return ((MemberExpression)property.Body).Member; + } + internal static bool ParameterTypesMatch(ParameterInfo[] parameters, System.Type[] types) { if (parameters.Length != types.Length) diff --git a/src/NHibernate/nhibernate-configuration.xsd b/src/NHibernate/nhibernate-configuration.xsd index 85fc25825e1..c8280b03de8 100644 --- a/src/NHibernate/nhibernate-configuration.xsd +++ b/src/NHibernate/nhibernate-configuration.xsd @@ -152,6 +152,40 @@ + + + + Whether to use the legacy pre-evaluation or not in Linq queries. true by default. + + Legacy pre-evaluation is causing special properties or functions like DateTime.Now or Guid.NewGuid() + to be always evaluated with the .Net runtime and replaced in the query by parameter values. + + The new pre-evaluation allows them to be converted to HQL function calls which will be run on the db + side. This allows for example to retrieve the server time instead of the client time, or to generate + UUIDs for each row instead of an unique one for all rows. + + The new pre-evaluation will likely be enabled by default in the next major version (6.0). + + + + + + + When the new pre-evaluation is enabled, should methods which translation is not supported by the current + dialect fallback to pre-evaluation? false by default. + + When this fallback option is enabled while legacy pre-evaluation is disabled, properties or functions + like DateTime.Now or Guid.NewGuid() used in Linq expressions will not fail when the dialect does not + support them, but will instead be pre-evaluated. + + When this fallback option is disabled while legacy pre-evaluation is disabled, properties or functions + like DateTime.Now or Guid.NewGuid() used in Linq expressions will fail when the dialect does not + support them. + + This option has no effect if the legacy pre-evaluation is enabled. + + + From 130ee880a6537d77008b5dd295826d2dc360cb5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Thu, 4 Oct 2018 19:37:55 +0200 Subject: [PATCH 11/43] Support evaluation of Guid.NewGuid() on db side Part of #959 Co-authored-by: maca88 --- .../Linq/PreEvaluationTests.cs | 40 ++++++++++++++++ src/NHibernate/Dialect/FirebirdDialect.cs | 1 + src/NHibernate/Dialect/HanaDialectBase.cs | 1 + src/NHibernate/Dialect/MsSql2000Dialect.cs | 2 + src/NHibernate/Dialect/MsSqlCeDialect.cs | 2 + src/NHibernate/Dialect/MySQL5Dialect.cs | 6 +++ src/NHibernate/Dialect/Oracle8iDialect.cs | 2 + src/NHibernate/Dialect/PostgreSQLDialect.cs | 4 ++ src/NHibernate/Dialect/SQLiteDialect.cs | 2 + src/NHibernate/Dialect/SybaseASE15Dialect.cs | 5 ++ .../Dialect/SybaseSQLAnywhere10Dialect.cs | 1 + .../DefaultLinqToHqlGeneratorsRegistry.cs | 2 + .../IAllowPreEvaluationHqlGenerator.cs | 6 +-- .../Linq/Functions/IHqlGeneratorForMethod.cs | 24 ++++++++++ .../Linq/Functions/NewGuidHqlGenerator.cs | 48 +++++++++++++++++++ .../NhPartialEvaluatingExpressionVisitor.cs | 10 +++- 16 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs diff --git a/src/NHibernate.Test/Linq/PreEvaluationTests.cs b/src/NHibernate.Test/Linq/PreEvaluationTests.cs index e5921352f16..aec3b583298 100644 --- a/src/NHibernate.Test/Linq/PreEvaluationTests.cs +++ b/src/NHibernate.Test/Linq/PreEvaluationTests.cs @@ -251,6 +251,46 @@ private void RunTest(bool isSupported, Action test) Assert.Fail("The test should have thrown a QueryException, but has not thrown anything"); } + [Test] + public void CanQueryByNewGuid() + { + if (!TestDialect.SupportsSqlType(SqlTypeFactory.Guid)) + Assert.Ignore("Guid are not supported by the target database"); + + var isSupported = IsFunctionSupported("new_uuid"); + RunTest( + isSupported, + spy => + { + var guid = Guid.NewGuid(); + var x = db.Orders.Count(o => guid != Guid.NewGuid()); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("new_uuid", spy); + }); + } + + [Test] + public void CanSelectNewGuid() + { + if (!TestDialect.SupportsSqlType(SqlTypeFactory.Guid)) + Assert.Ignore("Guid are not supported by the target database"); + + var isSupported = IsFunctionSupported("new_uuid"); + RunTest( + isSupported, + spy => + { + var x = + db + .Orders.Select(o => new { id = o.OrderId, g = Guid.NewGuid() }) + .OrderBy(o => o.id).Take(1).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + AssertFunctionInSql("new_uuid", spy); + }); + } + private void AssertFunctionInSql(string functionName, SqlLogSpy spy) { if (!IsFunctionSupported(functionName)) diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs index e4b8b5dcb8a..09c7fbf1f89 100644 --- a/src/NHibernate/Dialect/FirebirdDialect.cs +++ b/src/NHibernate/Dialect/FirebirdDialect.cs @@ -423,6 +423,7 @@ private void OverrideStandardHQLFunctions() RegisterFunction("strguid", new StandardSQLFunction("uuid_to_char", NHibernateUtil.String)); RegisterFunction("sysdate", new CastedFunction("today", NHibernateUtil.Date)); RegisterFunction("date", new SQLFunctionTemplate(NHibernateUtil.Date, "cast(?1 as date)")); + RegisterFunction("new_uuid", new NoArgSQLFunction("gen_uuid", NHibernateUtil.Guid)); // Bitwise operations RegisterFunction("band", new Function.BitwiseFunctionOperation("bin_and")); RegisterFunction("bor", new Function.BitwiseFunctionOperation("bin_or")); diff --git a/src/NHibernate/Dialect/HanaDialectBase.cs b/src/NHibernate/Dialect/HanaDialectBase.cs index 150ba9aab2d..d94da6f10bc 100644 --- a/src/NHibernate/Dialect/HanaDialectBase.cs +++ b/src/NHibernate/Dialect/HanaDialectBase.cs @@ -395,6 +395,7 @@ protected virtual void RegisterNHibernateFunctions() RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); RegisterFunction("sysdate", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, false)); RegisterFunction("truncate", new SQLFunctionTemplateWithRequiredParameters(null, "floor(?1 * power(10, ?2)) / power(10, ?2)", new object[] { null, "0" })); + RegisterFunction("new_uuid", new NoArgSQLFunction("sysuuid", NHibernateUtil.Guid, false)); } protected virtual void RegisterHANAFunctions() diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index e4dd5a9e0af..aab663e0880 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -360,6 +360,8 @@ protected virtual void RegisterFunctions() RegisterFunction("bit_length", new SQLFunctionTemplate(NHibernateUtil.Int32, "datalength(?1) * 8")); RegisterFunction("extract", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(?1, ?3)")); + + RegisterFunction("new_uuid", new NoArgSQLFunction("newid", NHibernateUtil.Guid)); } protected virtual void RegisterGuidTypeMapping() diff --git a/src/NHibernate/Dialect/MsSqlCeDialect.cs b/src/NHibernate/Dialect/MsSqlCeDialect.cs index 46c15f40cf5..fd0baed402b 100644 --- a/src/NHibernate/Dialect/MsSqlCeDialect.cs +++ b/src/NHibernate/Dialect/MsSqlCeDialect.cs @@ -201,6 +201,8 @@ protected virtual void RegisterFunctions() RegisterFunction("bit_length", new SQLFunctionTemplate(NHibernateUtil.Int32, "datalength(?1) * 8")); RegisterFunction("extract", new SQLFunctionTemplate(NHibernateUtil.Int32, "datepart(?1, ?3)")); + + RegisterFunction("new_uuid", new NoArgSQLFunction("newid", NHibernateUtil.Guid)); } protected virtual void RegisterDefaultProperties() diff --git a/src/NHibernate/Dialect/MySQL5Dialect.cs b/src/NHibernate/Dialect/MySQL5Dialect.cs index 1dfac2f6f46..0797206b02a 100644 --- a/src/NHibernate/Dialect/MySQL5Dialect.cs +++ b/src/NHibernate/Dialect/MySQL5Dialect.cs @@ -12,8 +12,14 @@ public MySQL5Dialect() // My SQL supports precision up to 65, but .Net is limited to 28-29. RegisterColumnType(DbType.Decimal, 29, "DECIMAL($p, $s)"); RegisterColumnType(DbType.Guid, "BINARY(16)"); + } + + protected override void RegisterFunctions() + { + base.RegisterFunctions(); RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "concat(hex(reverse(substr(?1, 1, 4))), '-', hex(reverse(substring(?1, 5, 2))), '-', hex(reverse(substr(?1, 7, 2))), '-', hex(substr(?1, 9, 2)), '-', hex(substr(?1, 11)))")); + RegisterFunction("new_uuid", new NoArgSQLFunction("uuid", NHibernateUtil.Guid)); } protected override void RegisterCastTypes() diff --git a/src/NHibernate/Dialect/Oracle8iDialect.cs b/src/NHibernate/Dialect/Oracle8iDialect.cs index 5d6bad35705..749c1f0d056 100644 --- a/src/NHibernate/Dialect/Oracle8iDialect.cs +++ b/src/NHibernate/Dialect/Oracle8iDialect.cs @@ -310,6 +310,8 @@ protected virtual void RegisterFunctions() RegisterFunction("bor", new SQLFunctionTemplate(null, "?1 + ?2 - BITAND(?1, ?2)")); RegisterFunction("bxor", new SQLFunctionTemplate(null, "?1 + ?2 - BITAND(?1, ?2) * 2")); RegisterFunction("bnot", new SQLFunctionTemplate(null, "(-1 - ?1)")); + + RegisterFunction("new_uuid", new NoArgSQLFunction("sys_guid", NHibernateUtil.Guid)); } protected internal virtual void RegisterDefaultProperties() diff --git a/src/NHibernate/Dialect/PostgreSQLDialect.cs b/src/NHibernate/Dialect/PostgreSQLDialect.cs index da522e533f3..baa9334d788 100644 --- a/src/NHibernate/Dialect/PostgreSQLDialect.cs +++ b/src/NHibernate/Dialect/PostgreSQLDialect.cs @@ -98,6 +98,10 @@ public PostgreSQLDialect() RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "?1::TEXT")); + // The uuid_generate_v4 is not native and must be installed, but SelectGUIDString property already uses it, + // and NHibernate.TestDatabaseSetup does install it. + RegisterFunction("new_uuid", new NoArgSQLFunction("uuid_generate_v4", NHibernateUtil.Guid)); + RegisterKeywords(); } diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index 22506edf333..f8124d486c4 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -61,6 +61,8 @@ protected virtual void RegisterColumnTypes() RegisterColumnType(DbType.DateTime, "DATETIME"); RegisterColumnType(DbType.Time, "TIME"); RegisterColumnType(DbType.Boolean, "BOOL"); + // UNIQUEIDENTIFIER is not a SQLite type, but SQLite does not care much, see + // https://www.sqlite.org/datatype3.html RegisterColumnType(DbType.Guid, "UNIQUEIDENTIFIER"); } diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index de9514431ca..0a2800c5bb6 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -56,6 +56,9 @@ public SybaseASE15Dialect() RegisterColumnType(DbType.Date, "date"); RegisterColumnType(DbType.Binary, 8000, "varbinary($l)"); RegisterColumnType(DbType.Binary, "varbinary"); + // newid default is to generate a 32 bytes character uuid (no-dashes), but it has an option for + // including dashes, then raising it to 36 bytes. + RegisterColumnType(DbType.Guid, "varchar(36)"); RegisterFunction("abs", new StandardSQLFunction("abs")); RegisterFunction("acos", new StandardSQLFunction("acos", NHibernateUtil.Double)); @@ -113,6 +116,8 @@ public SybaseASE15Dialect() RegisterFunction("year", new StandardSQLFunction("year", NHibernateUtil.Int32)); RegisterFunction("substring", new EmulatedLengthSubstringFunction()); + + RegisterFunction("new_uuid", new NoArgSQLFunction("newid", NHibernateUtil.Guid)); } public override string AddColumnString diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs index 10256111696..99da6ce9fbd 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs @@ -338,6 +338,7 @@ protected virtual void RegisterMiscellaneousFunctions() RegisterFunction("isnull", new VarArgsSQLFunction("isnull(", ",", ")")); RegisterFunction("lesser", new StandardSQLFunction("lesser")); RegisterFunction("newid", new NoArgSQLFunction("newid", NHibernateUtil.String, true)); + RegisterFunction("new_uuid", new NoArgSQLFunction("newid", NHibernateUtil.Guid)); RegisterFunction("nullif", new StandardSQLFunction("nullif")); RegisterFunction("number", new NoArgSQLFunction("number", NHibernateUtil.Int32)); RegisterFunction("plan", new VarArgsSQLFunction(NHibernateUtil.String, "plan(", ",", ")")); diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index 27c28e1bf72..0e921dff7eb 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -58,6 +58,8 @@ public DefaultLinqToHqlGeneratorsRegistry() this.Merge(new DateTimePropertiesHqlGenerator()); this.Merge(new DateTimeNowHqlGenerator()); + this.Merge(new NewGuidHqlGenerator()); + this.Merge(new DecimalAddGenerator()); this.Merge(new DecimalDivideGenerator()); this.Merge(new DecimalMultiplyGenerator()); diff --git a/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs b/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs index ab2266b0f9a..f15afa7c7bc 100644 --- a/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs +++ b/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs @@ -7,12 +7,12 @@ namespace NHibernate.Linq.Functions public interface IAllowPreEvaluationHqlGenerator { /// - /// Should pre-evaluation be allowed for this property? + /// Should pre-evaluation be allowed for this property or method? /// - /// The property. + /// The property or method. /// The session factory. /// - /// if the property should be evaluated before running the query whenever possible, + /// if the property or method should be evaluated before running the query whenever possible, /// if it must always be translated to the equivalent HQL call. /// /// Implementors should return by default. Returning diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index fde4ffd45f0..c4b871e6380 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -2,6 +2,7 @@ using System.Collections.ObjectModel; using System.Linq.Expressions; using System.Reflection; +using NHibernate.Engine; using NHibernate.Hql.Ast; using NHibernate.Linq.Visitors; @@ -31,5 +32,28 @@ public static bool AllowsNullableReturnType(this IHqlGeneratorForMethod generato return true; } + + // 6.0 TODO: merge into IHqlGeneratorForMethod + /// + /// Should pre-evaluation be allowed for this method? + /// + /// The method's HQL generator. + /// The method. + /// The session factory. + /// + /// if the method should be evaluated before running the query whenever possible, + /// if it must always be translated to the equivalent HQL call. + /// + public static bool AllowPreEvaluation( + this IHqlGeneratorForMethod generator, + MemberInfo member, + ISessionFactoryImplementor factory) + { + if (generator is IAllowPreEvaluationHqlGenerator allowPreEvalGenerator) + return allowPreEvalGenerator.AllowPreEvaluation(member, factory); + + // By default, everything should be pre-evaluated whenever possible. + return true; + } } } diff --git a/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs b/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs new file mode 100644 index 00000000000..a33720aed55 --- /dev/null +++ b/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Engine; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using Environment = NHibernate.Cfg.Environment; + +namespace NHibernate.Linq.Functions +{ + public class NewGuidHqlGenerator : BaseHqlGeneratorForMethod, IAllowPreEvaluationHqlGenerator + { + public NewGuidHqlGenerator() + { + SupportedMethods = new[] + { + ReflectHelper.GetMethod(() => Guid.NewGuid()) + }; + } + + public override HqlTreeNode BuildHql( + MethodInfo method, + Expression targetObject, + ReadOnlyCollection arguments, + HqlTreeBuilder treeBuilder, + IHqlExpressionVisitor visitor) + { + return treeBuilder.MethodCall("new_uuid"); + } + + public bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor factory) + { + if (factory.Dialect.Functions.ContainsKey("new_uuid")) + return false; + + if (factory.Settings.LinqToHqlFallbackOnPreEvaluation) + return true; + + throw new QueryException( + "Cannot translate NewGuid: new_uuid is " + + $"not supported by {factory.Dialect}. Either enable the fallback on pre-evaluation " + + $"({Environment.LinqToHqlFallbackOnPreEvaluation}) or evaluate NewGuid " + + "outside of the query."); + } + } +} diff --git a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs index dccc28df9ab..5cfb22fa178 100644 --- a/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/NhPartialEvaluatingExpressionVisitor.cs @@ -88,8 +88,14 @@ public override bool IsEvaluatableMethodCall(MethodCallExpression node) var attributes = node.Method .GetCustomAttributes(typeof(LinqExtensionMethodAttributeBase), false) .ToArray(x => (LinqExtensionMethodAttributeBase) x); - return attributes.Length == 0 || - attributes.Any(a => a.PreEvaluation == LinqExtensionPreEvaluation.AllowPreEvaluation); + if (attributes.Length > 0) + return attributes.Any(a => a.PreEvaluation == LinqExtensionPreEvaluation.AllowPreEvaluation); + + if (_sessionFactory == null || _sessionFactory.Settings.LinqToHqlLegacyPreEvaluation || + !_sessionFactory.Settings.LinqToHqlGeneratorsRegistry.TryGetGenerator(node.Method, out var generator)) + return true; + + return generator.AllowPreEvaluation(node.Method, _sessionFactory); } } } From 032e4646f8b3af5157216bcbc95d6da8e9653548 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Sat, 6 Oct 2018 14:28:28 +0200 Subject: [PATCH 12/43] Support evaluation of Random.Next and NextDouble on db side Fixes #959, along with two previous commits --- .../Async/Linq/PreEvaluationTests.cs | 152 ++++++++++++ .../Linq/PreEvaluationTests.cs | 217 ++++++++++++++++++ src/NHibernate/Dialect/DB2Dialect.cs | 1 + src/NHibernate/Dialect/FirebirdDialect.cs | 1 + src/NHibernate/Dialect/HanaDialectBase.cs | 1 + src/NHibernate/Dialect/MsSql2000Dialect.cs | 2 + src/NHibernate/Dialect/MySQLDialect.cs | 3 +- src/NHibernate/Dialect/Oracle10gDialect.cs | 28 ++- src/NHibernate/Dialect/SQLiteDialect.cs | 10 + src/NHibernate/Dialect/SybaseASE15Dialect.cs | 2 + .../Dialect/SybaseSQLAnywhere10Dialect.cs | 1 + .../Linq/Functions/DateTimeNowHqlGenerator.cs | 6 + .../DefaultLinqToHqlGeneratorsRegistry.cs | 1 + .../IAllowPreEvaluationHqlGenerator.cs | 10 + .../Linq/Functions/IHqlGeneratorForMethod.cs | 17 ++ .../Linq/Functions/NewGuidHqlGenerator.cs | 6 + .../Linq/Functions/RandomHqlGenerator.cs | 109 +++++++++ .../Linq/Visitors/SelectClauseNominator.cs | 11 +- 18 files changed, 572 insertions(+), 6 deletions(-) create mode 100644 src/NHibernate.Test/Async/Linq/PreEvaluationTests.cs create mode 100644 src/NHibernate/Linq/Functions/RandomHqlGenerator.cs diff --git a/src/NHibernate.Test/Async/Linq/PreEvaluationTests.cs b/src/NHibernate.Test/Async/Linq/PreEvaluationTests.cs new file mode 100644 index 00000000000..1f55f7b6cb8 --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/PreEvaluationTests.cs @@ -0,0 +1,152 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Cfg; +using NHibernate.SqlTypes; +using NUnit.Framework; +using Environment = NHibernate.Cfg.Environment; +using NHibernate.Linq; + +namespace NHibernate.Test.Linq +{ + using System.Threading.Tasks; + [TestFixture(false, false)] + [TestFixture(true, false)] + [TestFixture(false, true)] + public class PreEvaluationTestsAsync : LinqTestCase + { + private readonly bool LegacyPreEvaluation; + private readonly bool FallbackOnPreEvaluation; + + public PreEvaluationTestsAsync(bool legacy, bool fallback) + { + LegacyPreEvaluation = legacy; + FallbackOnPreEvaluation = fallback; + } + + protected override void Configure(Configuration configuration) + { + base.Configure(configuration); + + configuration.SetProperty(Environment.FormatSql, "false"); + configuration.SetProperty(Environment.LinqToHqlLegacyPreEvaluation, LegacyPreEvaluation.ToString()); + configuration.SetProperty(Environment.LinqToHqlFallbackOnPreEvaluation, FallbackOnPreEvaluation.ToString()); + } + + private void RunTest(bool isSupported, Action test) + { + using (var spy = new SqlLogSpy()) + { + try + { + test(spy); + } + catch (QueryException) + { + if (!isSupported && !FallbackOnPreEvaluation) + // Expected failure + return; + throw; + } + } + + if (!isSupported && !FallbackOnPreEvaluation) + Assert.Fail("The test should have thrown a QueryException, but has not thrown anything"); + } + + [Test] + public async Task CanQueryByRandomIntAsync() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + var idMin = await (db.Orders.MinAsync(o => o.OrderId)); + RunTest( + isSupported, + spy => + { + var random = new Random(); + // Dodge a Firebird driver limitation by putting the constants before the order id. + // This driver cast parameters to their types in some cases for avoiding Firebird complaining of not + // knowing the type of the condition. For some reasons the driver considers the casting should not be + // done next to the conditional operator. Having the cast only on one side is enough for avoiding + // Firebird complain, so moving the constants on the left side have been put before the order id, in + // order for these constants to be casted by the driver. + var x = db.Orders.Count(o => -idMin - 1 + o.OrderId < random.Next()); + + Assert.That(x, Is.GreaterThan(0)); + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public async Task CanQueryByRandomIntWithMaxAsync() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + var idMin = await (db.Orders.MinAsync(o => o.OrderId)); + RunTest( + isSupported, + spy => + { + var random = new Random(); + // Dodge a Firebird driver limitation by putting the constants before the order id. + // This driver cast parameters to their types in some cases for avoiding Firebird complaining of not + // knowing the type of the condition. For some reasons the driver considers the casting should not be + // done next to the conditional operator. Having the cast only on one side is enough for avoiding + // Firebird complain, so moving the constants on the left side have been put before the order id, in + // order for these constants to be casted by the driver. + var x = db.Orders.Count(o => -idMin + o.OrderId <= random.Next(10)); + + Assert.That(x, Is.GreaterThan(0).And.LessThan(11)); + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public async Task CanQueryByRandomIntWithMinMaxAsync() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + var idMin = await (db.Orders.MinAsync(o => o.OrderId)); + RunTest( + isSupported, + spy => + { + var random = new Random(); + // Dodge a Firebird driver limitation by putting the constants before the order id. + // This driver cast parameters to their types in some cases for avoiding Firebird complaining of not + // knowing the type of the condition. For some reasons the driver considers the casting should not be + // done next to the conditional operator. Having the cast only on one side is enough for avoiding + // Firebird complain, so moving the constants on the left side have been put before the order id, in + // order for these constants to be casted by the driver. + var x = db.Orders.Count(o => -idMin + o.OrderId < random.Next(1, 10)); + + Assert.That(x, Is.GreaterThan(0).And.LessThan(10)); + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + private void AssertFunctionInSql(string functionName, SqlLogSpy spy) + { + if (!IsFunctionSupported(functionName)) + Assert.Inconclusive($"{functionName} is not supported by the dialect"); + + var function = Dialect.Functions[functionName].Render(new List(), Sfi).ToString(); + + if (LegacyPreEvaluation) + Assert.That(spy.GetWholeLog(), Does.Not.Contain(function)); + else + Assert.That(spy.GetWholeLog(), Does.Contain(function)); + } + } +} diff --git a/src/NHibernate.Test/Linq/PreEvaluationTests.cs b/src/NHibernate.Test/Linq/PreEvaluationTests.cs index aec3b583298..4e12c4b9d82 100644 --- a/src/NHibernate.Test/Linq/PreEvaluationTests.cs +++ b/src/NHibernate.Test/Linq/PreEvaluationTests.cs @@ -291,6 +291,223 @@ public void CanSelectNewGuid() }); } + [Test] + public void CanQueryByRandomDouble() + { + var isSupported = IsFunctionSupported("random"); + RunTest( + isSupported, + spy => + { + var random = new Random(); + var x = db.Orders.Count(o => o.OrderId > random.NextDouble()); + + Assert.That(x, Is.GreaterThan(0)); + AssertFunctionInSql("random", spy); + }); + } + + [Test] + public void CanSelectRandomDouble() + { + var isSupported = IsFunctionSupported("random"); + RunTest( + isSupported, + spy => + { + var random = new Random(); + var x = + db + .Orders.Select(o => new { id = o.OrderId, r = random.NextDouble() }) + .OrderBy(o => o.id).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + var randomValues = x.Select(o => o.r).Distinct().ToArray(); + Assert.That(randomValues, Has.All.GreaterThanOrEqualTo(0).And.LessThan(1)); + + if (!LegacyPreEvaluation && IsFunctionSupported("random")) + { + // Naïve randomness check + Assert.That( + randomValues, + Has.Length.GreaterThan(x.Count / 2), + "Generated values do not seem very random"); + } + + AssertFunctionInSql("random", spy); + }); + } + + [Test] + public void CanQueryByRandomInt() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + var idMin = db.Orders.Min(o => o.OrderId); + RunTest( + isSupported, + spy => + { + var random = new Random(); + // Dodge a Firebird driver limitation by putting the constants before the order id. + // This driver cast parameters to their types in some cases for avoiding Firebird complaining of not + // knowing the type of the condition. For some reasons the driver considers the casting should not be + // done next to the conditional operator. Having the cast only on one side is enough for avoiding + // Firebird complain, so moving the constants on the left side have been put before the order id, in + // order for these constants to be casted by the driver. + var x = db.Orders.Count(o => -idMin - 1 + o.OrderId < random.Next()); + + Assert.That(x, Is.GreaterThan(0)); + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public void CanSelectRandomInt() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + RunTest( + isSupported, + spy => + { + var random = new Random(); + var x = + db + .Orders.Select(o => new { id = o.OrderId, r = random.Next() }) + .OrderBy(o => o.id).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + var randomValues = x.Select(o => o.r).Distinct().ToArray(); + Assert.That( + randomValues, + Has.All.GreaterThanOrEqualTo(0).And.LessThan(int.MaxValue).And.TypeOf()); + + if (!LegacyPreEvaluation && IsFunctionSupported("random") && IsFunctionSupported("floor")) + { + // Naïve randomness check + Assert.That( + randomValues, + Has.Length.GreaterThan(x.Count / 2), + "Generated values do not seem very random"); + } + + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public void CanQueryByRandomIntWithMax() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + var idMin = db.Orders.Min(o => o.OrderId); + RunTest( + isSupported, + spy => + { + var random = new Random(); + // Dodge a Firebird driver limitation by putting the constants before the order id. + // This driver cast parameters to their types in some cases for avoiding Firebird complaining of not + // knowing the type of the condition. For some reasons the driver considers the casting should not be + // done next to the conditional operator. Having the cast only on one side is enough for avoiding + // Firebird complain, so moving the constants on the left side have been put before the order id, in + // order for these constants to be casted by the driver. + var x = db.Orders.Count(o => -idMin + o.OrderId <= random.Next(10)); + + Assert.That(x, Is.GreaterThan(0).And.LessThan(11)); + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public void CanSelectRandomIntWithMax() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + RunTest( + isSupported, + spy => + { + var random = new Random(); + var x = + db + .Orders.Select(o => new { id = o.OrderId, r = random.Next(10) }) + .OrderBy(o => o.id).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + var randomValues = x.Select(o => o.r).Distinct().ToArray(); + Assert.That(randomValues, Has.All.GreaterThanOrEqualTo(0).And.LessThan(10).And.TypeOf()); + + if (!LegacyPreEvaluation && IsFunctionSupported("random") && IsFunctionSupported("floor")) + { + // Naïve randomness check + Assert.That( + randomValues, + Has.Length.GreaterThan(Math.Min(10, x.Count) / 2), + "Generated values do not seem very random"); + } + + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public void CanQueryByRandomIntWithMinMax() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + var idMin = db.Orders.Min(o => o.OrderId); + RunTest( + isSupported, + spy => + { + var random = new Random(); + // Dodge a Firebird driver limitation by putting the constants before the order id. + // This driver cast parameters to their types in some cases for avoiding Firebird complaining of not + // knowing the type of the condition. For some reasons the driver considers the casting should not be + // done next to the conditional operator. Having the cast only on one side is enough for avoiding + // Firebird complain, so moving the constants on the left side have been put before the order id, in + // order for these constants to be casted by the driver. + var x = db.Orders.Count(o => -idMin + o.OrderId < random.Next(1, 10)); + + Assert.That(x, Is.GreaterThan(0).And.LessThan(10)); + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + + [Test] + public void CanSelectRandomIntWithMinMax() + { + var isSupported = IsFunctionSupported("random") && IsFunctionSupported("floor"); + RunTest( + isSupported, + spy => + { + var random = new Random(); + var x = + db + .Orders.Select(o => new { id = o.OrderId, r = random.Next(1, 11) }) + .OrderBy(o => o.id).ToList(); + + Assert.That(x, Has.Count.GreaterThan(0)); + var randomValues = x.Select(o => o.r).Distinct().ToArray(); + Assert.That(randomValues, Has.All.GreaterThanOrEqualTo(1).And.LessThan(11).And.TypeOf()); + + if (!LegacyPreEvaluation && IsFunctionSupported("random") && IsFunctionSupported("floor")) + { + // Naïve randomness check + Assert.That( + randomValues, + Has.Length.GreaterThan(Math.Min(10, x.Count) / 2), + "Generated values do not seem very random"); + } + + // Next requires support of both floor and rand + AssertFunctionInSql(IsFunctionSupported("floor") ? "random" : "floor", spy); + }); + } + private void AssertFunctionInSql(string functionName, SqlLogSpy spy) { if (!IsFunctionSupported(functionName)) diff --git a/src/NHibernate/Dialect/DB2Dialect.cs b/src/NHibernate/Dialect/DB2Dialect.cs index 81c7aae473d..bd24dda28d7 100644 --- a/src/NHibernate/Dialect/DB2Dialect.cs +++ b/src/NHibernate/Dialect/DB2Dialect.cs @@ -80,6 +80,7 @@ public DB2Dialect() RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double)); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); + RegisterFunction("random", new NoArgSQLFunction("rand", NHibernateUtil.Double)); RegisterFunction("sin", new StandardSQLFunction("sin", NHibernateUtil.Double)); RegisterFunction("soundex", new StandardSQLFunction("soundex", NHibernateUtil.String)); RegisterFunction("sqrt", new StandardSQLFunction("sqrt", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/FirebirdDialect.cs b/src/NHibernate/Dialect/FirebirdDialect.cs index 09c7fbf1f89..ba37c00cfaa 100644 --- a/src/NHibernate/Dialect/FirebirdDialect.cs +++ b/src/NHibernate/Dialect/FirebirdDialect.cs @@ -465,6 +465,7 @@ private void RegisterMathematicalFunctions() RegisterFunction("log10", new StandardSQLFunction("log10", NHibernateUtil.Double)); RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double)); RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); + RegisterFunction("random", new NoArgSQLFunction("rand", NHibernateUtil.Double)); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); RegisterFunction("sqtr", new StandardSQLFunction("sqtr", NHibernateUtil.Double)); RegisterFunction("trunc", new StandardSQLFunction("trunc")); diff --git a/src/NHibernate/Dialect/HanaDialectBase.cs b/src/NHibernate/Dialect/HanaDialectBase.cs index d94da6f10bc..4a8d70db2e8 100644 --- a/src/NHibernate/Dialect/HanaDialectBase.cs +++ b/src/NHibernate/Dialect/HanaDialectBase.cs @@ -396,6 +396,7 @@ protected virtual void RegisterNHibernateFunctions() RegisterFunction("sysdate", new NoArgSQLFunction("current_timestamp", NHibernateUtil.DateTime, false)); RegisterFunction("truncate", new SQLFunctionTemplateWithRequiredParameters(null, "floor(?1 * power(10, ?2)) / power(10, ?2)", new object[] { null, "0" })); RegisterFunction("new_uuid", new NoArgSQLFunction("sysuuid", NHibernateUtil.Guid, false)); + RegisterFunction("random", new NoArgSQLFunction("rand", NHibernateUtil.Double)); } protected virtual void RegisterHANAFunctions() diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index aab663e0880..7acb3cb9b4f 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -315,6 +315,8 @@ protected virtual void RegisterFunctions() RegisterFunction("mod", new SQLFunctionTemplate(NHibernateUtil.Int32, "((?1) % (?2))")); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); + // SQL Server rand returns the same value for each row, unless hacking it with a random seed per row + RegisterFunction("random", new SQLFunctionTemplate(NHibernateUtil.Double, "rand(checksum(newid()))")); RegisterFunction("sin", new StandardSQLFunction("sin", NHibernateUtil.Double)); RegisterFunction("soundex", new StandardSQLFunction("soundex", NHibernateUtil.String)); RegisterFunction("sqrt", new StandardSQLFunction("sqrt", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/MySQLDialect.cs b/src/NHibernate/Dialect/MySQLDialect.cs index a1816e95209..a6caa6d236a 100644 --- a/src/NHibernate/Dialect/MySQLDialect.cs +++ b/src/NHibernate/Dialect/MySQLDialect.cs @@ -265,7 +265,8 @@ protected virtual void RegisterFunctions() RegisterFunction("truncate", new StandardSQLFunctionWithRequiredParameters("truncate", new object[] {null, "0"})); RegisterFunction("rand", new NoArgSQLFunction("rand", NHibernateUtil.Double)); - + RegisterFunction("random", new NoArgSQLFunction("rand", NHibernateUtil.Double)); + RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); RegisterFunction("stddev", new StandardSQLFunction("stddev", NHibernateUtil.Double)); diff --git a/src/NHibernate/Dialect/Oracle10gDialect.cs b/src/NHibernate/Dialect/Oracle10gDialect.cs index caab3e1f492..1ad7f135b44 100644 --- a/src/NHibernate/Dialect/Oracle10gDialect.cs +++ b/src/NHibernate/Dialect/Oracle10gDialect.cs @@ -1,3 +1,4 @@ +using NHibernate.Dialect.Function; using NHibernate.SqlCommand; namespace NHibernate.Dialect @@ -16,7 +17,32 @@ public override JoinFragment CreateOuterJoinFragment() return new ANSIJoinFragment(); } + protected override void RegisterFunctions() + { + base.RegisterFunctions(); + + // DBMS_RANDOM package was available in previous versions, but it was requiring initialization and + // was not having the value function. + // It yields a decimal between 0 included and 1 excluded, with 38 significant digits. It sometimes + // causes an overflow when read by the Oracle provider as a .Net Decimal, so better explicitly cast + // it to double. + RegisterFunction("random", new SQLFunctionTemplate(NHibernateUtil.Double, "cast(DBMS_RANDOM.VALUE() as binary_double)")); + } + + /* 6.0 TODO: consider redefining float and double registrations + protected override void RegisterNumericTypeMappings() + { + base.RegisterNumericTypeMappings(); + + // Use binary_float (available since 10g) instead of float. With Oracle, float is a decimal but + // with a precision expressed in number of bytes instead of digits. + RegisterColumnType(DbType.Single, "binary_float"); + // Using binary_double (available since 10g) instead of double precision. With Oracle, double + // precision is a float(126), which is a decimal with a 126 bytes precision. + RegisterColumnType(DbType.Double, "binary_double"); + }*/ + /// public override bool SupportsCrossJoin => true; } -} \ No newline at end of file +} diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index f8124d486c4..864defac9a3 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -113,6 +113,16 @@ protected virtual void RegisterFunctions() RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "substr(hex(?1), 7, 2) || substr(hex(?1), 5, 2) || substr(hex(?1), 3, 2) || substr(hex(?1), 1, 2) || '-' || substr(hex(?1), 11, 2) || substr(hex(?1), 9, 2) || '-' || substr(hex(?1), 15, 2) || substr(hex(?1), 13, 2) || '-' || substr(hex(?1), 17, 4) || '-' || substr(hex(?1), 21) ")); else RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as char)")); + + // SQLite random function yields a long, ranging form MinValue to MaxValue. (-9223372036854775808 to + // 9223372036854775807). HQL random requires a float from 0 inclusive to 1 exclusive, so we divide by + // 9223372036854775808 then 2 for having a value between -0.5 included to 0.5 excluded, and finally + // add 0.5. The division is written as "/ 4611686018427387904 / 4" for avoiding overflowing long. + RegisterFunction( + "random", + new SQLFunctionTemplate( + NHibernateUtil.Double, + "(cast(random() as real) / 4611686018427387904 / 4 + 0.5)")); } public override void Configure(IDictionary settings) diff --git a/src/NHibernate/Dialect/SybaseASE15Dialect.cs b/src/NHibernate/Dialect/SybaseASE15Dialect.cs index 0a2800c5bb6..a5233395d36 100644 --- a/src/NHibernate/Dialect/SybaseASE15Dialect.cs +++ b/src/NHibernate/Dialect/SybaseASE15Dialect.cs @@ -98,6 +98,8 @@ public SybaseASE15Dialect() RegisterFunction("pi", new NoArgSQLFunction("pi", NHibernateUtil.Double)); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); RegisterFunction("rand", new StandardSQLFunction("rand", NHibernateUtil.Double)); + // rand returns the same value for each row, rand2 returns a new one for each row. + RegisterFunction("random", new StandardSQLFunction("rand2", NHibernateUtil.Double)); RegisterFunction("reverse", new StandardSQLFunction("reverse")); RegisterFunction("round", new StandardSQLFunction("round")); RegisterFunction("rtrim", new StandardSQLFunction("rtrim")); diff --git a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs index 99da6ce9fbd..f5b61250049 100644 --- a/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs +++ b/src/NHibernate/Dialect/SybaseSQLAnywhere10Dialect.cs @@ -142,6 +142,7 @@ protected virtual void RegisterMathFunctions() RegisterFunction("power", new StandardSQLFunction("power", NHibernateUtil.Double)); RegisterFunction("radians", new StandardSQLFunction("radians", NHibernateUtil.Double)); RegisterFunction("rand", new StandardSQLFunction("rand", NHibernateUtil.Double)); + RegisterFunction("random", new StandardSQLFunction("rand", NHibernateUtil.Double)); RegisterFunction("remainder", new StandardSQLFunction("remainder")); RegisterFunction("round", new StandardSQLFunctionWithRequiredParameters("round", new object[] {null, "0"})); RegisterFunction("sign", new StandardSQLFunction("sign", NHibernateUtil.Int32)); diff --git a/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs b/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs index ef2c154b81d..9039693fc1c 100644 --- a/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs +++ b/src/NHibernate/Linq/Functions/DateTimeNowHqlGenerator.cs @@ -70,5 +70,11 @@ public bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor fac $"({Environment.LinqToHqlFallbackOnPreEvaluation}) or evaluate {member.Name} " + "outside of the query."); } + + public bool IgnoreInstance(MemberInfo member) + { + // They are all static properties + return true; + } } } diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index 0e921dff7eb..29595877d9f 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -59,6 +59,7 @@ public DefaultLinqToHqlGeneratorsRegistry() this.Merge(new DateTimeNowHqlGenerator()); this.Merge(new NewGuidHqlGenerator()); + this.Merge(new RandomHqlGenerator()); this.Merge(new DecimalAddGenerator()); this.Merge(new DecimalDivideGenerator()); diff --git a/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs b/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs index f15afa7c7bc..2cd67b7d2d1 100644 --- a/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs +++ b/src/NHibernate/Linq/Functions/IAllowPreEvaluationHqlGenerator.cs @@ -20,5 +20,15 @@ public interface IAllowPreEvaluationHqlGenerator /// a function which value on server side can differ from the equivalent client value, like /// . bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor factory); + + /// + /// Should the instance holding the property or method be ignored? + /// + /// The property or method. + /// + /// if the property or method translation does not depend on the instance to which it + /// belongs, otherwise. + /// + bool IgnoreInstance(MemberInfo member); } } diff --git a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs index c4b871e6380..73ad8b3d9e4 100644 --- a/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs +++ b/src/NHibernate/Linq/Functions/IHqlGeneratorForMethod.cs @@ -55,5 +55,22 @@ public static bool AllowPreEvaluation( // By default, everything should be pre-evaluated whenever possible. return true; } + + /// + /// Should the instance holding the method be ignored? + /// + /// The method's HQL generator. + /// The method. + /// + /// if the method translation does not depend on the instance to which it + /// belongs, otherwise. + /// + public static bool IgnoreInstance(this IHqlGeneratorForMethod generator, MemberInfo member) + { + if (generator is IAllowPreEvaluationHqlGenerator allowPreEvalGenerator) + return allowPreEvalGenerator.IgnoreInstance(member); + + return false; + } } } diff --git a/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs b/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs index a33720aed55..2318d19f51b 100644 --- a/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs +++ b/src/NHibernate/Linq/Functions/NewGuidHqlGenerator.cs @@ -44,5 +44,11 @@ public bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor fac $"({Environment.LinqToHqlFallbackOnPreEvaluation}) or evaluate NewGuid " + "outside of the query."); } + + public bool IgnoreInstance(MemberInfo member) + { + // There is only a static method + return true; + } } } diff --git a/src/NHibernate/Linq/Functions/RandomHqlGenerator.cs b/src/NHibernate/Linq/Functions/RandomHqlGenerator.cs new file mode 100644 index 00000000000..bde4895e1e5 --- /dev/null +++ b/src/NHibernate/Linq/Functions/RandomHqlGenerator.cs @@ -0,0 +1,109 @@ +using System; +using System.Collections.ObjectModel; +using System.Linq.Expressions; +using System.Reflection; +using NHibernate.Engine; +using NHibernate.Hql.Ast; +using NHibernate.Linq.Visitors; +using NHibernate.Util; +using Environment = NHibernate.Cfg.Environment; + +namespace NHibernate.Linq.Functions +{ + public class RandomHqlGenerator : BaseHqlGeneratorForMethod, IAllowPreEvaluationHqlGenerator + { + private readonly MethodInfo _nextDouble = ReflectHelper.GetMethod(r => r.NextDouble()); + private const string _randomFunctionName = "random"; + private const string _floorFunctionName = "floor"; + + public RandomHqlGenerator() + { + SupportedMethods = new[] + { + _nextDouble, + ReflectHelper.GetMethod(r => r.Next()), + ReflectHelper.GetMethod(r => r.Next(2)), + ReflectHelper.GetMethod(r => r.Next(-1, 1)) + }; + } + + public override HqlTreeNode BuildHql( + MethodInfo method, + Expression targetObject, + ReadOnlyCollection arguments, + HqlTreeBuilder treeBuilder, + IHqlExpressionVisitor visitor) + { + if (method == _nextDouble) + return treeBuilder.MethodCall(_randomFunctionName); + + switch (arguments.Count) + { + case 0: + return treeBuilder.Cast( + treeBuilder.MethodCall( + _floorFunctionName, + treeBuilder.Multiply( + treeBuilder.MethodCall(_randomFunctionName), + treeBuilder.Constant(int.MaxValue))), + typeof(int)); + case 1: + return treeBuilder.Cast( + treeBuilder.MethodCall( + _floorFunctionName, + treeBuilder.Multiply( + treeBuilder.MethodCall(_randomFunctionName), + visitor.Visit(arguments[0]).AsExpression())), + typeof(int)); + case 2: + var minValue = visitor.Visit(arguments[0]).AsExpression(); + var maxValue = visitor.Visit(arguments[1]).AsExpression(); + return treeBuilder.Cast( + treeBuilder.Add( + treeBuilder.MethodCall( + _floorFunctionName, + treeBuilder.Multiply( + treeBuilder.MethodCall(_randomFunctionName), + treeBuilder.Subtract(maxValue, minValue))), + minValue), + typeof(int)); + default: + throw new NotSupportedException(); + } + } + + /// + public bool AllowPreEvaluation(MemberInfo member, ISessionFactoryImplementor factory) + { + if (factory.Dialect.Functions.ContainsKey(_randomFunctionName) && + (member == _nextDouble || factory.Dialect.Functions.ContainsKey(_floorFunctionName))) + return false; + + if (factory.Settings.LinqToHqlFallbackOnPreEvaluation) + return true; + + var functionName = factory.Dialect.Functions.ContainsKey(_randomFunctionName) + ? _floorFunctionName + : _randomFunctionName; + throw new QueryException( + $"Cannot translate {member.DeclaringType.Name}.{member.Name}: {functionName} is " + + $"not supported by {factory.Dialect}. Either enable the fallback on pre-evaluation " + + $"({Environment.LinqToHqlFallbackOnPreEvaluation}) or evaluate {member.Name} " + + "outside of the query."); + } + + /// + public bool IgnoreInstance(MemberInfo member) + { + // The translation ignores the Random instance, so long if it was specifically seeded: the user should + // pass the random value as a local variable in the Linq query in such case. + // Returning false here would cause the method, when appearing in a select clause, to be post-evaluated. + // Contrary to pre-evaluation, the post-evaluation is done for each row so it at least would avoid having + // the same random value for each result. + // But that would still be not executed in database which would be unexpected, in my opinion. + // It would even cause failures if the random instance used for querying is shared among threads or is + // too similarly seeded between queries. + return true; + } + } +} diff --git a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs index 085d6681926..9205b4f1b05 100644 --- a/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs +++ b/src/NHibernate/Linq/Visitors/SelectClauseNominator.cs @@ -114,11 +114,14 @@ private bool IsRegisteredFunction(Expression expression) if (expression.NodeType == ExpressionType.Call) { var methodCallExpression = (MethodCallExpression) expression; - IHqlGeneratorForMethod methodGenerator; - if (_functionRegistry.TryGetGenerator(methodCallExpression.Method, out methodGenerator)) + if (_functionRegistry.TryGetGenerator(methodCallExpression.Method, out var methodGenerator)) { - return methodCallExpression.Object == null || // is static or extension method - methodCallExpression.Object.NodeType != ExpressionType.Constant; // does not belong to parameter + // is static or extension method + return methodCallExpression.Object == null || + // does not belong to parameter + methodCallExpression.Object.NodeType != ExpressionType.Constant || + // does not ignore the parameter it belongs to + methodGenerator.IgnoreInstance(methodCallExpression.Method); } } else if (expression is NhSumExpression || From f4b8bd77d1814ed2679632416f310c3af8038bcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Sun, 22 Mar 2020 18:56:10 +0100 Subject: [PATCH 13/43] Fix some NRE in work isolation and connection handling Fixes #2336 --- src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs | 2 +- src/NHibernate/Connection/ConnectionProvider.cs | 3 +++ src/NHibernate/Transaction/AdoNetTransactionFactory.cs | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs b/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs index ec3ea6bc578..22f50f1876a 100644 --- a/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs +++ b/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs @@ -126,7 +126,7 @@ async Task InternalExecuteWorkInIsolationAsync() isolaterLog.Warn(ignore, "Unable to dispose transaction"); } - if (session.Factory.Dialect is SQLiteDialect == false) + if (connection != null && session.Factory.Dialect is SQLiteDialect == false) session.Factory.ConnectionProvider.CloseConnection(connection); } } diff --git a/src/NHibernate/Connection/ConnectionProvider.cs b/src/NHibernate/Connection/ConnectionProvider.cs index a49012c1262..e30c2fdf21d 100644 --- a/src/NHibernate/Connection/ConnectionProvider.cs +++ b/src/NHibernate/Connection/ConnectionProvider.cs @@ -24,6 +24,9 @@ public abstract partial class ConnectionProvider : IConnectionProvider /// The to clean up. public virtual void CloseConnection(DbConnection conn) { + if (conn == null) + throw new ArgumentNullException(nameof(conn)); + log.Debug("Closing connection"); try { diff --git a/src/NHibernate/Transaction/AdoNetTransactionFactory.cs b/src/NHibernate/Transaction/AdoNetTransactionFactory.cs index e1d485aeecb..7e7ef81e2e5 100644 --- a/src/NHibernate/Transaction/AdoNetTransactionFactory.cs +++ b/src/NHibernate/Transaction/AdoNetTransactionFactory.cs @@ -136,7 +136,7 @@ public virtual void ExecuteWorkInIsolation(ISessionImplementor session, IIsolate isolaterLog.Warn(ignore, "Unable to dispose transaction"); } - if (session.Factory.Dialect is SQLiteDialect == false) + if (connection != null && session.Factory.Dialect is SQLiteDialect == false) session.Factory.ConnectionProvider.CloseConnection(connection); } } From 1af6757f5ab72d7bf16b84856b0a1d0b4fb5385f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Sun, 22 Mar 2020 19:34:19 +0100 Subject: [PATCH 14/43] Clean up a bit isolation work code --- .../Transaction/AdoNetTransactionFactory.cs | 56 +++++-------------- .../Transaction/AdoNetTransactionFactory.cs | 54 +++++------------- 2 files changed, 29 insertions(+), 81 deletions(-) diff --git a/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs b/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs index 22f50f1876a..76466173109 100644 --- a/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs +++ b/src/NHibernate/Async/Transaction/AdoNetTransactionFactory.cs @@ -16,7 +16,6 @@ using NHibernate.Engine; using NHibernate.Engine.Transaction; using NHibernate.Exceptions; -using NHibernate.Impl; namespace NHibernate.Transaction { @@ -42,25 +41,17 @@ async Task InternalExecuteWorkInIsolationAsync() DbConnection connection = null; DbTransaction trans = null; - // bool wasAutoCommit = false; try { // We make an exception for SQLite and use the session's connection, // since SQLite only allows one connection to the database. - if (session.Factory.Dialect is SQLiteDialect) - connection = session.Connection; - else - connection = await (session.Factory.ConnectionProvider.GetConnectionAsync(cancellationToken)).ConfigureAwait(false); + connection = session.Factory.Dialect is SQLiteDialect + ? session.Connection + : await (session.Factory.ConnectionProvider.GetConnectionAsync(cancellationToken)).ConfigureAwait(false); if (transacted) { trans = connection.BeginTransaction(); - // TODO NH: a way to read the autocommit state is needed - //if (TransactionManager.GetAutoCommit(connection)) - //{ - // wasAutoCommit = true; - // TransactionManager.SetAutoCommit(connection, false); - //} } await (work.DoWorkAsync(connection, trans, cancellationToken)).ConfigureAwait(false); @@ -68,7 +59,6 @@ async Task InternalExecuteWorkInIsolationAsync() if (transacted) { trans.Commit(); - //TransactionManager.Commit(connection); } } catch (Exception t) @@ -84,46 +74,30 @@ async Task InternalExecuteWorkInIsolationAsync() } catch (Exception ignore) { - isolaterLog.Debug(ignore, "Unable to rollback transaction"); + _isolatorLog.Debug(ignore, "Unable to rollback transaction"); } - if (t is HibernateException) - { - throw; - } - else if (t is DbException) - { - throw ADOExceptionHelper.Convert(session.Factory.SQLExceptionConverter, t, - "error performing isolated work"); - } - else - { - throw new HibernateException("error performing isolated work", t); - } + switch (t) + { + case HibernateException _: + throw; + case DbException _: + throw ADOExceptionHelper.Convert(session.Factory.SQLExceptionConverter, t, + "error performing isolated work"); + default: + throw new HibernateException("error performing isolated work", t); + } } } finally { - //if (transacted && wasAutoCommit) - //{ - // try - // { - // // TODO NH: reset autocommit - // // TransactionManager.SetAutoCommit(connection, true); - // } - // catch (Exception) - // { - // log.Debug("was unable to reset connection back to auto-commit"); - // } - //} - try { trans?.Dispose(); } catch (Exception ignore) { - isolaterLog.Warn(ignore, "Unable to dispose transaction"); + _isolatorLog.Warn(ignore, "Unable to dispose transaction"); } if (connection != null && session.Factory.Dialect is SQLiteDialect == false) diff --git a/src/NHibernate/Transaction/AdoNetTransactionFactory.cs b/src/NHibernate/Transaction/AdoNetTransactionFactory.cs index 7e7ef81e2e5..5fe50f06757 100644 --- a/src/NHibernate/Transaction/AdoNetTransactionFactory.cs +++ b/src/NHibernate/Transaction/AdoNetTransactionFactory.cs @@ -6,7 +6,6 @@ using NHibernate.Engine; using NHibernate.Engine.Transaction; using NHibernate.Exceptions; -using NHibernate.Impl; namespace NHibernate.Transaction { @@ -16,7 +15,7 @@ namespace NHibernate.Transaction /// public partial class AdoNetTransactionFactory : ITransactionFactory { - private readonly INHibernateLogger isolaterLog = NHibernateLogger.For(typeof(ITransactionFactory)); + private static readonly INHibernateLogger _isolatorLog = NHibernateLogger.For(typeof(ITransactionFactory)); /// public virtual ITransaction CreateTransaction(ISessionImplementor session) @@ -52,25 +51,17 @@ public virtual void ExecuteWorkInIsolation(ISessionImplementor session, IIsolate DbConnection connection = null; DbTransaction trans = null; - // bool wasAutoCommit = false; try { // We make an exception for SQLite and use the session's connection, // since SQLite only allows one connection to the database. - if (session.Factory.Dialect is SQLiteDialect) - connection = session.Connection; - else - connection = session.Factory.ConnectionProvider.GetConnection(); + connection = session.Factory.Dialect is SQLiteDialect + ? session.Connection + : session.Factory.ConnectionProvider.GetConnection(); if (transacted) { trans = connection.BeginTransaction(); - // TODO NH: a way to read the autocommit state is needed - //if (TransactionManager.GetAutoCommit(connection)) - //{ - // wasAutoCommit = true; - // TransactionManager.SetAutoCommit(connection, false); - //} } work.DoWork(connection, trans); @@ -78,7 +69,6 @@ public virtual void ExecuteWorkInIsolation(ISessionImplementor session, IIsolate if (transacted) { trans.Commit(); - //TransactionManager.Commit(connection); } } catch (Exception t) @@ -94,46 +84,30 @@ public virtual void ExecuteWorkInIsolation(ISessionImplementor session, IIsolate } catch (Exception ignore) { - isolaterLog.Debug(ignore, "Unable to rollback transaction"); + _isolatorLog.Debug(ignore, "Unable to rollback transaction"); } - if (t is HibernateException) + switch (t) { - throw; - } - else if (t is DbException) - { - throw ADOExceptionHelper.Convert(session.Factory.SQLExceptionConverter, t, - "error performing isolated work"); - } - else - { - throw new HibernateException("error performing isolated work", t); + case HibernateException _: + throw; + case DbException _: + throw ADOExceptionHelper.Convert(session.Factory.SQLExceptionConverter, t, + "error performing isolated work"); + default: + throw new HibernateException("error performing isolated work", t); } } } finally { - //if (transacted && wasAutoCommit) - //{ - // try - // { - // // TODO NH: reset autocommit - // // TransactionManager.SetAutoCommit(connection, true); - // } - // catch (Exception) - // { - // log.Debug("was unable to reset connection back to auto-commit"); - // } - //} - try { trans?.Dispose(); } catch (Exception ignore) { - isolaterLog.Warn(ignore, "Unable to dispose transaction"); + _isolatorLog.Warn(ignore, "Unable to dispose transaction"); } if (connection != null && session.Factory.Dialect is SQLiteDialect == false) From 9e12fb4d13b06fa581068371ce3d34eeb08c28e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Sun, 29 Mar 2020 16:49:09 +0200 Subject: [PATCH 15/43] Support setting parameters with a dynamic object (#2321) Fixes #2009 --- .../Async/Legacy/SQLFunctionsTest.cs | 108 ++++++++++++++++++ .../Async/Legacy/SQLLoaderTest.cs | 59 ++++++++++ .../Legacy/SQLFunctionsTest.cs | 108 ++++++++++++++++++ src/NHibernate.Test/Legacy/SQLLoaderTest.cs | 59 ++++++++++ src/NHibernate/Impl/AbstractQueryImpl.cs | 50 ++++++++ 5 files changed, 384 insertions(+) diff --git a/src/NHibernate.Test/Async/Legacy/SQLFunctionsTest.cs b/src/NHibernate.Test/Async/Legacy/SQLFunctionsTest.cs index 5acc3e36f8f..9ab540d2e8f 100644 --- a/src/NHibernate.Test/Async/Legacy/SQLFunctionsTest.cs +++ b/src/NHibernate.Test/Async/Legacy/SQLFunctionsTest.cs @@ -11,6 +11,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Dynamic; using log4net; using NHibernate.Dialect; using NHibernate.Dialect.Function; @@ -131,6 +132,113 @@ public async Task SetPropertiesAsync() s.Close(); } + [Test] + public async Task SetParametersWithDictionaryAsync() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + await (s.SaveAsync(simple, 10L)); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and s.Count = :Count"); + var parameters = new Dictionary + { + { nameof(simple.Name), simple.Name }, + { nameof(simple.Count), simple.Count }, + }; + q.SetProperties(parameters); + var results = await (q.ListAsync()); + Assert.That(results, Has.One.EqualTo(simple)); + await (s.DeleteAsync(simple)); + await (t.CommitAsync()); + } + } + + [Test] + public async Task SetParametersWithHashtableAsync() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + await (s.SaveAsync(simple, 10L)); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and (s.Address = :Address or :Address is null and s.Address is null)"); + var parameters = new Hashtable + { + { nameof(simple.Name), simple.Name }, + { nameof(simple.Address), simple.Address }, + }; + q.SetProperties(parameters); + var results = await (q.ListAsync()); + Assert.That(results, Has.One.EqualTo(simple)); + await (s.DeleteAsync(simple)); + await (t.CommitAsync()); + } + } + + [Test] + public async Task SetParametersWithDynamicAsync() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + await (s.SaveAsync(simple, 10L)); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and s.Count = :Count"); + dynamic parameters = new ExpandoObject(); + parameters.Name = simple.Name; + parameters.Count = simple.Count; + q.SetProperties(parameters); + var results = await (q.ListAsync()); + Assert.That(results, Has.One.EqualTo(simple)); + await (s.DeleteAsync(simple)); + await (t.CommitAsync()); + } + } + + [Test] + public async Task SetNullParameterWithDictionaryAsync() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + await (s.SaveAsync(simple, 10L)); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and (s.Address = :Address or :Address is null and s.Address is null)"); + var parameters = new Dictionary + { + { nameof(simple.Name), simple.Name }, + { nameof(simple.Address), null }, + }; + q.SetProperties(parameters); + var results = await (q.ListAsync()); + Assert.That(results, Has.One.EqualTo(simple)); + await (s.DeleteAsync(simple)); + await (t.CommitAsync()); + } + } + + [Test] + public async Task SetParameterListWithDictionaryAsync() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + await (s.SaveAsync(simple, 10L)); + var q = s.CreateQuery("from s in class Simple where s.Name in (:Name)"); + var parameters = new Dictionary + { + { nameof(simple.Name), new [] {simple.Name} } + }; + q.SetProperties(parameters); + var results = await (q.ListAsync()); + Assert.That(results, Has.One.EqualTo(simple)); + await (s.DeleteAsync(simple)); + await (t.CommitAsync()); + } + } + [Test] public async Task BrokenAsync() { diff --git a/src/NHibernate.Test/Async/Legacy/SQLLoaderTest.cs b/src/NHibernate.Test/Async/Legacy/SQLLoaderTest.cs index fc607833a7c..3d9119be986 100644 --- a/src/NHibernate.Test/Async/Legacy/SQLLoaderTest.cs +++ b/src/NHibernate.Test/Async/Legacy/SQLLoaderTest.cs @@ -11,6 +11,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Dynamic; using NHibernate.Dialect; using NHibernate.DomainModel; using NUnit.Framework; @@ -152,6 +153,64 @@ public async Task FindBySQLPropertiesAsync() session.Close(); } + [Test] + public async Task FindBySQLDictionaryAsync() + { + using (var session = OpenSession()) + using (var tran = session.BeginTransaction()) + { + var s = new Category { Name = nextLong.ToString() }; + nextLong++; + await (session.SaveAsync(s)); + + s = new Category { Name = "WannaBeFound" }; + await (session.FlushAsync()); + + var query = + session.CreateSQLQuery("select {category.*} from Category {category} where {category}.Name = :Name") + .AddEntity("category", typeof(Category)); + var parameters = new Dictionary + { + { nameof(s.Name), s.Name } + }; + query.SetProperties(parameters); + var results = await (query.ListAsync()); + Assert.That(results, Is.Empty); + + await (session.DeleteAsync("from Category")); + await (tran.CommitAsync()); + } + } + + [Test] + public async Task FindBySQLDynamicAsync() + { + using (var session = OpenSession()) + using (var tran = session.BeginTransaction()) + { + var s = new Category { Name = nextLong.ToString() }; + nextLong++; + await (session.SaveAsync(s)); + + s = new Category { Name = "WannaBeFound" }; + await (session.FlushAsync()); + + var query = + session.CreateSQLQuery("select {category.*} from Category {category} where {category}.Name = :Name") + .AddEntity("category", typeof(Category)); + dynamic parameters = new ExpandoObject(); + parameters.Name = s.Name; + // dynamic does not work on inherited interface method calls. https://stackoverflow.com/q/3071634 + IQuery q = query; + q.SetProperties(parameters); + var results = await (query.ListAsync()); + Assert.That(results, Is.Empty); + + await (session.DeleteAsync("from Category")); + await (tran.CommitAsync()); + } + } + [Test] public async Task FindBySQLAssociatedObjectAsync() { diff --git a/src/NHibernate.Test/Legacy/SQLFunctionsTest.cs b/src/NHibernate.Test/Legacy/SQLFunctionsTest.cs index 8a2fcd8a6cd..1a75c230105 100644 --- a/src/NHibernate.Test/Legacy/SQLFunctionsTest.cs +++ b/src/NHibernate.Test/Legacy/SQLFunctionsTest.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Dynamic; using log4net; using NHibernate.Dialect; using NHibernate.Dialect.Function; @@ -137,6 +138,113 @@ public void SetProperties() s.Close(); } + [Test] + public void SetParametersWithDictionary() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + s.Save(simple, 10L); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and s.Count = :Count"); + var parameters = new Dictionary + { + { nameof(simple.Name), simple.Name }, + { nameof(simple.Count), simple.Count }, + }; + q.SetProperties(parameters); + var results = q.List(); + Assert.That(results, Has.One.EqualTo(simple)); + s.Delete(simple); + t.Commit(); + } + } + + [Test] + public void SetParametersWithHashtable() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + s.Save(simple, 10L); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and (s.Address = :Address or :Address is null and s.Address is null)"); + var parameters = new Hashtable + { + { nameof(simple.Name), simple.Name }, + { nameof(simple.Address), simple.Address }, + }; + q.SetProperties(parameters); + var results = q.List(); + Assert.That(results, Has.One.EqualTo(simple)); + s.Delete(simple); + t.Commit(); + } + } + + [Test] + public void SetParametersWithDynamic() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + s.Save(simple, 10L); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and s.Count = :Count"); + dynamic parameters = new ExpandoObject(); + parameters.Name = simple.Name; + parameters.Count = simple.Count; + q.SetProperties(parameters); + var results = q.List(); + Assert.That(results, Has.One.EqualTo(simple)); + s.Delete(simple); + t.Commit(); + } + } + + [Test] + public void SetNullParameterWithDictionary() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + s.Save(simple, 10L); + var q = s.CreateQuery("from s in class Simple where s.Name = :Name and (s.Address = :Address or :Address is null and s.Address is null)"); + var parameters = new Dictionary + { + { nameof(simple.Name), simple.Name }, + { nameof(simple.Address), null }, + }; + q.SetProperties(parameters); + var results = q.List(); + Assert.That(results, Has.One.EqualTo(simple)); + s.Delete(simple); + t.Commit(); + } + } + + [Test] + public void SetParameterListWithDictionary() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var simple = new Simple { Name = "Simple 1" }; + s.Save(simple, 10L); + var q = s.CreateQuery("from s in class Simple where s.Name in (:Name)"); + var parameters = new Dictionary + { + { nameof(simple.Name), new [] {simple.Name} } + }; + q.SetProperties(parameters); + var results = q.List(); + Assert.That(results, Has.One.EqualTo(simple)); + s.Delete(simple); + t.Commit(); + } + } + [Test] public void Broken() { diff --git a/src/NHibernate.Test/Legacy/SQLLoaderTest.cs b/src/NHibernate.Test/Legacy/SQLLoaderTest.cs index ba2429530b4..5818a4f9909 100644 --- a/src/NHibernate.Test/Legacy/SQLLoaderTest.cs +++ b/src/NHibernate.Test/Legacy/SQLLoaderTest.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Dynamic; using NHibernate.Dialect; using NHibernate.DomainModel; using NUnit.Framework; @@ -140,6 +141,64 @@ public void FindBySQLProperties() session.Close(); } + [Test] + public void FindBySQLDictionary() + { + using (var session = OpenSession()) + using (var tran = session.BeginTransaction()) + { + var s = new Category { Name = nextLong.ToString() }; + nextLong++; + session.Save(s); + + s = new Category { Name = "WannaBeFound" }; + session.Flush(); + + var query = + session.CreateSQLQuery("select {category.*} from Category {category} where {category}.Name = :Name") + .AddEntity("category", typeof(Category)); + var parameters = new Dictionary + { + { nameof(s.Name), s.Name } + }; + query.SetProperties(parameters); + var results = query.List(); + Assert.That(results, Is.Empty); + + session.Delete("from Category"); + tran.Commit(); + } + } + + [Test] + public void FindBySQLDynamic() + { + using (var session = OpenSession()) + using (var tran = session.BeginTransaction()) + { + var s = new Category { Name = nextLong.ToString() }; + nextLong++; + session.Save(s); + + s = new Category { Name = "WannaBeFound" }; + session.Flush(); + + var query = + session.CreateSQLQuery("select {category.*} from Category {category} where {category}.Name = :Name") + .AddEntity("category", typeof(Category)); + dynamic parameters = new ExpandoObject(); + parameters.Name = s.Name; + // dynamic does not work on inherited interface method calls. https://stackoverflow.com/q/3071634 + IQuery q = query; + q.SetProperties(parameters); + var results = query.List(); + Assert.That(results, Is.Empty); + + session.Delete("from Category"); + tran.Commit(); + } + } + [Test] public void FindBySQLAssociatedObject() { diff --git a/src/NHibernate/Impl/AbstractQueryImpl.cs b/src/NHibernate/Impl/AbstractQueryImpl.cs index ba1e1402854..cd078ca7dd3 100644 --- a/src/NHibernate/Impl/AbstractQueryImpl.cs +++ b/src/NHibernate/Impl/AbstractQueryImpl.cs @@ -651,6 +651,8 @@ public IQuery SetEnum(string name, Enum val) return this; } + // Since 5.3 + [Obsolete("This method was never surfaced to a query interface. Use the overload taking an object instead, and supply to it a generic IDictionary.")] public IQuery SetProperties(IDictionary map) { string[] @params = NamedParameters; @@ -674,8 +676,56 @@ public IQuery SetProperties(IDictionary map) return this; } + private IQuery SetParameters(IDictionary map) + { + foreach (var namedParam in NamedParameters) + { + if (map.TryGetValue(namedParam, out var obj)) + { + switch (obj) + { + case IEnumerable enumerable when !(enumerable is string): + SetParameterList(namedParam, enumerable); + break; + default: + SetParameter(namedParam, obj); + break; + } + } + } + return this; + } + + private IQuery SetParameters(IDictionary map) + { + foreach (var namedParam in NamedParameters) + { + var obj = map[namedParam]; + switch (obj) + { + case IEnumerable enumerable when !(enumerable is string): + SetParameterList(namedParam, enumerable); + break; + case null when map.Contains(namedParam): + default: + SetParameter(namedParam, obj); + break; + } + } + return this; + } + public IQuery SetProperties(object bean) { + if (bean is IDictionary map) + { + return SetParameters(map); + } + if (bean is IDictionary hashtable) + { + return SetParameters(hashtable); + } + System.Type clazz = bean.GetType(); string[] @params = NamedParameters; for (int i = 0; i < @params.Length; i++) From 89fda9673f877d041a2d58113aad47cf5215206b Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 6 Mar 2020 19:06:56 +0100 Subject: [PATCH 16/43] Fix Linq Fetch/FetchMany after SelectMany method --- .../Async/Linq/EagerLoadTests.cs | 42 +++++++++++++++++++ src/NHibernate.Test/Linq/EagerLoadTests.cs | 42 +++++++++++++++++++ .../ResultOperatorProcessors/ProcessFetch.cs | 3 +- 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/NHibernate.Test/Async/Linq/EagerLoadTests.cs b/src/NHibernate.Test/Async/Linq/EagerLoadTests.cs index 47095f66ba7..655b0b41862 100644 --- a/src/NHibernate.Test/Async/Linq/EagerLoadTests.cs +++ b/src/NHibernate.Test/Async/Linq/EagerLoadTests.cs @@ -34,6 +34,48 @@ public async Task CanSelectAndFetchAsync() Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].Orders)); } + [Test] + public async Task CanSelectAndFetchManyAsync() + { + var result = await (db.OrderLines + .Select(o => o.Product) + .FetchMany(o => o.OrderLines) + .ToListAsync()); + + session.Close(); + + Assert.IsNotEmpty(result); + Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].OrderLines)); + } + + [Test] + public async Task CanSelectManyAndFetchAsync() + { + var result = await (db.Orders + .SelectMany(o => o.OrderLines) + .Fetch(o => o.Product) + .ToListAsync()); + + session.Close(); + + Assert.IsNotEmpty(result); + Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].Product)); + } + + [Test] + public async Task CanSelectManyAndFetchManyAsync() + { + var result = await (db.Employees + .SelectMany(o => o.Orders) + .FetchMany(o => o.OrderLines) + .ToListAsync()); + + session.Close(); + + Assert.IsNotEmpty(result); + Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].OrderLines)); + } + [Test] public async Task CanSelectAndFetchHqlAsync() { diff --git a/src/NHibernate.Test/Linq/EagerLoadTests.cs b/src/NHibernate.Test/Linq/EagerLoadTests.cs index d813e76a92f..3f8a680b266 100644 --- a/src/NHibernate.Test/Linq/EagerLoadTests.cs +++ b/src/NHibernate.Test/Linq/EagerLoadTests.cs @@ -23,6 +23,48 @@ public void CanSelectAndFetch() Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].Orders)); } + [Test] + public void CanSelectAndFetchMany() + { + var result = db.OrderLines + .Select(o => o.Product) + .FetchMany(o => o.OrderLines) + .ToList(); + + session.Close(); + + Assert.IsNotEmpty(result); + Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].OrderLines)); + } + + [Test] + public void CanSelectManyAndFetch() + { + var result = db.Orders + .SelectMany(o => o.OrderLines) + .Fetch(o => o.Product) + .ToList(); + + session.Close(); + + Assert.IsNotEmpty(result); + Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].Product)); + } + + [Test] + public void CanSelectManyAndFetchMany() + { + var result = db.Employees + .SelectMany(o => o.Orders) + .FetchMany(o => o.OrderLines) + .ToList(); + + session.Close(); + + Assert.IsNotEmpty(result); + Assert.IsTrue(NHibernateUtil.IsInitialized(result[0].OrderLines)); + } + [Test] public void CanSelectAndFetchHql() { diff --git a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFetch.cs b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFetch.cs index 102d1cb6dc2..4eb9a2cbcd6 100644 --- a/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFetch.cs +++ b/src/NHibernate/Linq/Visitors/ResultOperatorProcessors/ProcessFetch.cs @@ -13,8 +13,9 @@ public void Process(FetchRequestBase resultOperator, QueryModelVisitor queryMode var querySource = QuerySourceLocator.FindQuerySource( queryModelVisitor.Model, resultOperator.RelationMember.DeclaringType); + var name = queryModelVisitor.VisitorParameters.QuerySourceNamer.GetName(querySource); - Process(resultOperator, queryModelVisitor, tree, querySource.ItemName); + Process(resultOperator, queryModelVisitor, tree, name); } public void Process(FetchRequestBase resultOperator, QueryModelVisitor queryModelVisitor, IntermediateHqlTree tree, string sourceAlias) From d935eb8017cd90e5226bd8fb4f933c0abe31864a Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Tue, 31 Mar 2020 00:39:22 +0300 Subject: [PATCH 17/43] Refactored session List method for Criteria (#1627) --- .../Criteria/DetachedCriteriaSerializable.cs | 15 +++++++----- .../Criteria/DetachedCriteriaSerializable.cs | 15 +++++++----- .../Async/Impl/AbstractSessionImpl.cs | 13 +++++----- src/NHibernate/Async/Impl/CriteriaImpl.cs | 20 +++++++--------- src/NHibernate/Async/Impl/SessionImpl.cs | 20 ++++++++++------ .../Async/Impl/StatelessSessionImpl.cs | 19 +++++++++------ .../Async/Loader/Criteria/CriteriaLoader.cs | 24 +++++++++++++++++++ src/NHibernate/Impl/AbstractSessionImpl.cs | 13 +++++----- src/NHibernate/Impl/CriteriaImpl.cs | 18 ++++++-------- src/NHibernate/Impl/SessionImpl.cs | 19 +++++++++------ src/NHibernate/Impl/StatelessSessionImpl.cs | 18 ++++++++------ .../Loader/Criteria/CriteriaLoader.cs | 23 ++++++++++++++++++ src/NHibernate/Util/EnumerableExtensions.cs | 5 ++++ 13 files changed, 147 insertions(+), 75 deletions(-) diff --git a/src/NHibernate.Test/Async/Criteria/DetachedCriteriaSerializable.cs b/src/NHibernate.Test/Async/Criteria/DetachedCriteriaSerializable.cs index eac01bb46aa..acaabbf375b 100644 --- a/src/NHibernate.Test/Async/Criteria/DetachedCriteriaSerializable.cs +++ b/src/NHibernate.Test/Async/Criteria/DetachedCriteriaSerializable.cs @@ -137,13 +137,16 @@ public async Task ExecutableCriteriaAsync() await (SerializeAndListAsync(dc)); // Subquery - dc = DetachedCriteria.For(typeof(Student)) - .Add(Property.ForName("StudentNumber").Eq(232L)) - .SetProjection(Property.ForName("Name")); + if (TestDialect.SupportsOperatorAll) + { + dc = DetachedCriteria.For(typeof(Student)) + .Add(Property.ForName("StudentNumber").Eq(232L)) + .SetProjection(Property.ForName("Name")); - DetachedCriteria dcs = DetachedCriteria.For(typeof(Student)) - .Add(Subqueries.PropertyEqAll("Name", dc)); - await (SerializeAndListAsync(dc)); + DetachedCriteria dcs = DetachedCriteria.For(typeof(Student)) + .Add(Subqueries.PropertyEqAll("Name", dc)); + await (SerializeAndListAsync(dcs)); + } // SQLCriterion dc = DetachedCriteria.For(typeof(Student)) diff --git a/src/NHibernate.Test/Criteria/DetachedCriteriaSerializable.cs b/src/NHibernate.Test/Criteria/DetachedCriteriaSerializable.cs index d0a8f39e65d..2ecad73a656 100644 --- a/src/NHibernate.Test/Criteria/DetachedCriteriaSerializable.cs +++ b/src/NHibernate.Test/Criteria/DetachedCriteriaSerializable.cs @@ -391,13 +391,16 @@ public void ExecutableCriteria() SerializeAndList(dc); // Subquery - dc = DetachedCriteria.For(typeof(Student)) - .Add(Property.ForName("StudentNumber").Eq(232L)) - .SetProjection(Property.ForName("Name")); + if (TestDialect.SupportsOperatorAll) + { + dc = DetachedCriteria.For(typeof(Student)) + .Add(Property.ForName("StudentNumber").Eq(232L)) + .SetProjection(Property.ForName("Name")); - DetachedCriteria dcs = DetachedCriteria.For(typeof(Student)) - .Add(Subqueries.PropertyEqAll("Name", dc)); - SerializeAndList(dc); + DetachedCriteria dcs = DetachedCriteria.For(typeof(Student)) + .Add(Subqueries.PropertyEqAll("Name", dc)); + SerializeAndList(dcs); + } // SQLCriterion dc = DetachedCriteria.For(typeof(Student)) diff --git a/src/NHibernate/Async/Impl/AbstractSessionImpl.cs b/src/NHibernate/Async/Impl/AbstractSessionImpl.cs index 2d28c429495..be4e02a2734 100644 --- a/src/NHibernate/Async/Impl/AbstractSessionImpl.cs +++ b/src/NHibernate/Async/Impl/AbstractSessionImpl.cs @@ -30,6 +30,7 @@ using NHibernate.Persister.Entity; using NHibernate.Transaction; using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Impl { @@ -68,6 +69,7 @@ public virtual async Task> ListAsync(IQueryExpression query, QueryPa } } + //TODO 6.0: Make abstract public virtual async Task> ListAsync(CriteriaImpl criteria, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -79,17 +81,16 @@ public virtual async Task> ListAsync(CriteriaImpl criteria, Cancella } } + //TODO 6.0: Make virtual public abstract Task ListAsync(CriteriaImpl criteria, IList results, CancellationToken cancellationToken); + //{ + // ArrayHelper.AddAll(results, List(criteria)); + //} public virtual async Task ListAsync(CriteriaImpl criteria, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (BeginProcess()) - { - var results = new List(); - await (ListAsync(criteria, results, cancellationToken)).ConfigureAwait(false); - return results; - } + return (await (ListAsync(criteria, cancellationToken)).ConfigureAwait(false)).ToIList(); } public abstract Task ListFilterAsync(object collection, string filter, QueryParameters parameters, CancellationToken cancellationToken); diff --git a/src/NHibernate/Async/Impl/CriteriaImpl.cs b/src/NHibernate/Async/Impl/CriteriaImpl.cs index 81e735bc052..478ad6b58b6 100644 --- a/src/NHibernate/Async/Impl/CriteriaImpl.cs +++ b/src/NHibernate/Async/Impl/CriteriaImpl.cs @@ -29,18 +29,22 @@ public partial class CriteriaImpl : ICriteria, ISupportEntityJoinCriteria, ISupp public async Task ListAsync(CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); - var results = new List(); - await (ListAsync(results, cancellationToken)).ConfigureAwait(false); - return results; + return (await (ListAsync(cancellationToken)).ConfigureAwait(false)).ToIList(); } public async Task ListAsync(IList results, CancellationToken cancellationToken = default(CancellationToken)) + { + cancellationToken.ThrowIfCancellationRequested(); + ArrayHelper.AddAll(results, await (ListAsync(cancellationToken)).ConfigureAwait(false)); + } + + public async Task> ListAsync(CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); Before(); try { - await (session.ListAsync(this, results, cancellationToken)).ConfigureAwait(false); + return await (session.ListAsync(this, cancellationToken)).ConfigureAwait(false); } finally { @@ -48,14 +52,6 @@ public partial class CriteriaImpl : ICriteria, ISupportEntityJoinCriteria, ISupp } } - public async Task> ListAsync(CancellationToken cancellationToken = default(CancellationToken)) - { - cancellationToken.ThrowIfCancellationRequested(); - List results = new List(); - await (ListAsync(results, cancellationToken)).ConfigureAwait(false); - return results; - } - public async Task UniqueResultAsync(CancellationToken cancellationToken = default(CancellationToken)) { cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/NHibernate/Async/Impl/SessionImpl.cs b/src/NHibernate/Async/Impl/SessionImpl.cs index 7843a87a307..8335e43d62f 100644 --- a/src/NHibernate/Async/Impl/SessionImpl.cs +++ b/src/NHibernate/Async/Impl/SessionImpl.cs @@ -1121,7 +1121,7 @@ public override async Task> EnumerableFilterAsync(object colle } } - public override async Task ListAsync(CriteriaImpl criteria, IList results, CancellationToken cancellationToken) + public override async Task> ListAsync(CriteriaImpl criteria, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); using (BeginProcess()) @@ -1134,7 +1134,7 @@ public override async Task ListAsync(CriteriaImpl criteria, IList results, Cance for (int i = 0; i < size; i++) { - loaders[i] = new CriteriaLoader( + var loader = new CriteriaLoader( GetOuterJoinLoadable(implementors[i]), Factory, criteria, @@ -1142,7 +1142,8 @@ public override async Task ListAsync(CriteriaImpl criteria, IList results, Cance enabledFilters ); - spaces.UnionWith(loaders[i].QuerySpaces); + spaces.UnionWith(loader.QuerySpaces); + loaders[size - 1 - i] = loader; } await (AutoFlushIfRequiredAsync(spaces, cancellationToken)).ConfigureAwait(false); @@ -1152,11 +1153,9 @@ public override async Task ListAsync(CriteriaImpl criteria, IList results, Cance { try { - for (int i = size - 1; i >= 0; i--) - { - ArrayHelper.AddAll(results, await (loaders[i].ListAsync(this, cancellationToken)).ConfigureAwait(false)); - } + var results = await (loaders.LoadAllToListAsync(this, cancellationToken)).ConfigureAwait(false); success = true; + return results; } catch (OperationCanceledException) { throw; } catch (HibernateException) @@ -1176,6 +1175,13 @@ public override async Task ListAsync(CriteriaImpl criteria, IList results, Cance } } + //TODO 6.0: Remove (use base class implementation) + public override async Task ListAsync(CriteriaImpl criteria, IList results, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + ArrayHelper.AddAll(results, await (ListAsync(criteria, cancellationToken)).ConfigureAwait(false)); + } + /// /// remove any hard references to the entity that are held by the infrastructure /// (references held by application or other persistant instances are okay) diff --git a/src/NHibernate/Async/Impl/StatelessSessionImpl.cs b/src/NHibernate/Async/Impl/StatelessSessionImpl.cs index f1ba521b612..317f731c37a 100644 --- a/src/NHibernate/Async/Impl/StatelessSessionImpl.cs +++ b/src/NHibernate/Async/Impl/StatelessSessionImpl.cs @@ -122,7 +122,7 @@ public override async Task ListAsync(IQueryExpression queryExpression, QueryPara } } - public override async Task ListAsync(CriteriaImpl criteria, IList results, CancellationToken cancellationToken) + public override async Task> ListAsync(CriteriaImpl criteria, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); using (BeginProcess()) @@ -133,18 +133,16 @@ public override async Task ListAsync(CriteriaImpl criteria, IList results, Cance CriteriaLoader[] loaders = new CriteriaLoader[size]; for (int i = 0; i < size; i++) { - loaders[i] = new CriteriaLoader(GetOuterJoinLoadable(implementors[i]), Factory, + loaders[size - 1 - i] = new CriteriaLoader(GetOuterJoinLoadable(implementors[i]), Factory, criteria, implementors[i], EnabledFilters); } bool success = false; try { - for (int i = size - 1; i >= 0; i--) - { - ArrayHelper.AddAll(results, await (loaders[i].ListAsync(this, cancellationToken)).ConfigureAwait(false)); - } + var results = await (loaders.LoadAllToListAsync(this, cancellationToken)).ConfigureAwait(false); success = true; + return results; } catch (OperationCanceledException) { throw; } catch (HibernateException) @@ -159,11 +157,18 @@ public override async Task ListAsync(CriteriaImpl criteria, IList results, Cance finally { await (AfterOperationAsync(success, cancellationToken)).ConfigureAwait(false); + temporaryPersistenceContext.Clear(); } - temporaryPersistenceContext.Clear(); } } + //TODO 6.0: Remove (use base class implementation) + public override async Task ListAsync(CriteriaImpl criteria, IList results, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + ArrayHelper.AddAll(results, await (ListAsync(criteria, cancellationToken)).ConfigureAwait(false)); + } + public override Task EnumerableAsync(IQueryExpression queryExpression, QueryParameters queryParameters, CancellationToken cancellationToken) { throw new NotImplementedException(); diff --git a/src/NHibernate/Async/Loader/Criteria/CriteriaLoader.cs b/src/NHibernate/Async/Loader/Criteria/CriteriaLoader.cs index 8d5cbbd7cc6..ea6f8c3a5ce 100644 --- a/src/NHibernate/Async/Loader/Criteria/CriteriaLoader.cs +++ b/src/NHibernate/Async/Loader/Criteria/CriteriaLoader.cs @@ -11,6 +11,7 @@ using System.Collections; using System.Collections.Generic; using System.Data.Common; +using System.Linq; using NHibernate.Engine; using NHibernate.Impl; using NHibernate.Param; @@ -24,6 +25,29 @@ namespace NHibernate.Loader.Criteria { using System.Threading.Tasks; using System.Threading; + internal static partial class CriteriaLoaderExtensions + { + /// + /// Loads all loaders results to single typed list + /// + internal static async Task> LoadAllToListAsync(this IList loaders, ISessionImplementor session, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + var subresults = new List(loaders.Count); + foreach(var l in loaders) + { + subresults.Add(await (l.ListAsync(session, cancellationToken)).ConfigureAwait(false)); + } + + var results = new List(subresults.Sum(r => r.Count)); + foreach(var list in subresults) + { + ArrayHelper.AddAll(results, list); + } + return results; + } + } + public partial class CriteriaLoader : OuterJoinLoader { diff --git a/src/NHibernate/Impl/AbstractSessionImpl.cs b/src/NHibernate/Impl/AbstractSessionImpl.cs index 173a8cba7bd..74caba114cf 100644 --- a/src/NHibernate/Impl/AbstractSessionImpl.cs +++ b/src/NHibernate/Impl/AbstractSessionImpl.cs @@ -20,6 +20,7 @@ using NHibernate.Persister.Entity; using NHibernate.Transaction; using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Impl { @@ -149,6 +150,7 @@ public virtual IList List(IQueryExpression query, QueryParameters paramete } } + //TODO 6.0: Make abstract public virtual IList List(CriteriaImpl criteria) { using (BeginProcess()) @@ -159,16 +161,15 @@ public virtual IList List(CriteriaImpl criteria) } } + //TODO 6.0: Make virtual public abstract void List(CriteriaImpl criteria, IList results); + //{ + // ArrayHelper.AddAll(results, List(criteria)); + //} public virtual IList List(CriteriaImpl criteria) { - using (BeginProcess()) - { - var results = new List(); - List(criteria, results); - return results; - } + return List(criteria).ToIList(); } public abstract IList ListFilter(object collection, string filter, QueryParameters parameters); diff --git a/src/NHibernate/Impl/CriteriaImpl.cs b/src/NHibernate/Impl/CriteriaImpl.cs index 89e703e6d4a..40a196ee27d 100644 --- a/src/NHibernate/Impl/CriteriaImpl.cs +++ b/src/NHibernate/Impl/CriteriaImpl.cs @@ -280,17 +280,20 @@ public ICriteria Add(ICriterion expression) public IList List() { - var results = new List(); - List(results); - return results; + return List().ToIList(); } public void List(IList results) + { + ArrayHelper.AddAll(results, List()); + } + + public IList List() { Before(); try { - session.List(this, results); + return session.List(this); } finally { @@ -298,13 +301,6 @@ public void List(IList results) } } - public IList List() - { - List results = new List(); - List(results); - return results; - } - public T UniqueResult() { object result = UniqueResult(); diff --git a/src/NHibernate/Impl/SessionImpl.cs b/src/NHibernate/Impl/SessionImpl.cs index 95bf4473137..b690cd8d637 100644 --- a/src/NHibernate/Impl/SessionImpl.cs +++ b/src/NHibernate/Impl/SessionImpl.cs @@ -1698,7 +1698,7 @@ public IQueryOver QueryOver(string entityName, Expression> alia } } - public override void List(CriteriaImpl criteria, IList results) + public override IList List(CriteriaImpl criteria) { using (BeginProcess()) { @@ -1710,7 +1710,7 @@ public override void List(CriteriaImpl criteria, IList results) for (int i = 0; i < size; i++) { - loaders[i] = new CriteriaLoader( + var loader = new CriteriaLoader( GetOuterJoinLoadable(implementors[i]), Factory, criteria, @@ -1718,7 +1718,8 @@ public override void List(CriteriaImpl criteria, IList results) enabledFilters ); - spaces.UnionWith(loaders[i].QuerySpaces); + spaces.UnionWith(loader.QuerySpaces); + loaders[size - 1 - i] = loader; } AutoFlushIfRequired(spaces); @@ -1728,11 +1729,9 @@ public override void List(CriteriaImpl criteria, IList results) { try { - for (int i = size - 1; i >= 0; i--) - { - ArrayHelper.AddAll(results, loaders[i].List(this)); - } + var results = loaders.LoadAllToList(this); success = true; + return results; } catch (HibernateException) { @@ -1751,6 +1750,12 @@ public override void List(CriteriaImpl criteria, IList results) } } + //TODO 6.0: Remove (use base class implementation) + public override void List(CriteriaImpl criteria, IList results) + { + ArrayHelper.AddAll(results, List(criteria)); + } + public bool Contains(object obj) { using (BeginProcess()) diff --git a/src/NHibernate/Impl/StatelessSessionImpl.cs b/src/NHibernate/Impl/StatelessSessionImpl.cs index d717e25d62e..12ff0fde552 100644 --- a/src/NHibernate/Impl/StatelessSessionImpl.cs +++ b/src/NHibernate/Impl/StatelessSessionImpl.cs @@ -136,7 +136,7 @@ public override void List(IQueryExpression queryExpression, QueryParameters quer } } - public override void List(CriteriaImpl criteria, IList results) + public override IList List(CriteriaImpl criteria) { using (BeginProcess()) { @@ -146,18 +146,16 @@ public override void List(CriteriaImpl criteria, IList results) CriteriaLoader[] loaders = new CriteriaLoader[size]; for (int i = 0; i < size; i++) { - loaders[i] = new CriteriaLoader(GetOuterJoinLoadable(implementors[i]), Factory, + loaders[size - 1 - i] = new CriteriaLoader(GetOuterJoinLoadable(implementors[i]), Factory, criteria, implementors[i], EnabledFilters); } bool success = false; try { - for (int i = size - 1; i >= 0; i--) - { - ArrayHelper.AddAll(results, loaders[i].List(this)); - } + var results = loaders.LoadAllToList(this); success = true; + return results; } catch (HibernateException) { @@ -171,11 +169,17 @@ public override void List(CriteriaImpl criteria, IList results) finally { AfterOperation(success); + temporaryPersistenceContext.Clear(); } - temporaryPersistenceContext.Clear(); } } + //TODO 6.0: Remove (use base class implementation) + public override void List(CriteriaImpl criteria, IList results) + { + ArrayHelper.AddAll(results, List(criteria)); + } + public override IEnumerable Enumerable(IQueryExpression queryExpression, QueryParameters queryParameters) { throw new NotImplementedException(); diff --git a/src/NHibernate/Loader/Criteria/CriteriaLoader.cs b/src/NHibernate/Loader/Criteria/CriteriaLoader.cs index 45cf0bd37a8..e3b816d2661 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaLoader.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaLoader.cs @@ -1,6 +1,7 @@ using System.Collections; using System.Collections.Generic; using System.Data.Common; +using System.Linq; using NHibernate.Engine; using NHibernate.Impl; using NHibernate.Param; @@ -12,6 +13,28 @@ namespace NHibernate.Loader.Criteria { + internal static partial class CriteriaLoaderExtensions + { + /// + /// Loads all loaders results to single typed list + /// + internal static List LoadAllToList(this IList loaders, ISessionImplementor session) + { + var subresults = new List(loaders.Count); + foreach(var l in loaders) + { + subresults.Add(l.List(session)); + } + + var results = new List(subresults.Sum(r => r.Count)); + foreach(var list in subresults) + { + ArrayHelper.AddAll(results, list); + } + return results; + } + } + /// /// A Loader for queries. /// diff --git a/src/NHibernate/Util/EnumerableExtensions.cs b/src/NHibernate/Util/EnumerableExtensions.cs index a4ac7e52979..b0e35892b06 100644 --- a/src/NHibernate/Util/EnumerableExtensions.cs +++ b/src/NHibernate/Util/EnumerableExtensions.cs @@ -97,5 +97,10 @@ internal static List ToList(this List input, C { return input.ConvertAll(converter); } + + internal static IList ToIList(this IEnumerable list) + { + return list as IList ?? list.ToList(); + } } } From d71c527a478a5e13e8565b1c97704ae553d58f9b Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Tue, 31 Mar 2020 00:42:41 +0300 Subject: [PATCH 18/43] Call generic query.List from Linq queries (#2238) --- .../Linq/ByMethod/WithOptionsTests.cs | 2 ++ .../Async/Linq/DefaultQueryProvider.cs | 17 +++++++++++++++++ src/NHibernate/Linq/DefaultQueryProvider.cs | 16 ++++++++++++++++ src/NHibernate/Linq/LinqExtensionMethods.cs | 7 +++++-- src/NHibernate/Linq/NhQueryable.cs | 12 +++++++++++- 5 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/NHibernate.Test/Linq/ByMethod/WithOptionsTests.cs b/src/NHibernate.Test/Linq/ByMethod/WithOptionsTests.cs index b32a55f287a..13cba8d2061 100644 --- a/src/NHibernate.Test/Linq/ByMethod/WithOptionsTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/WithOptionsTests.cs @@ -23,6 +23,7 @@ public void AppliesOptionsToQuery() var query = Substitute.For(); query.List().Returns(new List()); + query.List().Returns(new List()); session.CreateQuery(Arg.Any()).Returns(query); @@ -56,6 +57,7 @@ public void DoNotContaminateQueryWithOptions() var query = Substitute.For(); query.List().Returns(new List()); + query.List().Returns(new List()); session.CreateQuery(Arg.Any()).Returns(query); diff --git a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs index d1e3549e354..4d0344a27eb 100644 --- a/src/NHibernate/Async/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Async/Linq/DefaultQueryProvider.cs @@ -32,6 +32,23 @@ public partial interface INhQueryProvider : IQueryProvider public partial class DefaultQueryProvider : INhQueryProvider, IQueryProviderWithOptions, ISupportFutureBatchNhQueryProvider { + //TODO 6.0: Add to INhQueryProvider interface + public virtual async Task> ExecuteListAsync(Expression expression, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + var linqExpression = PrepareQuery(expression, out var query); + var resultTransformer = linqExpression.ExpressionToHqlTranslationResults?.PostExecuteTransformer; + if (resultTransformer == null) + { + return await (query.ListAsync(cancellationToken)).ConfigureAwait(false); + } + + return new List + { + (TResult) resultTransformer.DynamicInvoke((await (query.ListAsync(cancellationToken)).ConfigureAwait(false)).AsQueryable()) + }; + } + // Since v5.1 [Obsolete("Use ExecuteQuery(NhLinqExpression nhLinqExpression, IQuery query) instead")] protected virtual async Task ExecuteQueryAsync(NhLinqExpression nhLinqExpression, IQuery query, NhLinqExpression nhQuery, CancellationToken cancellationToken) diff --git a/src/NHibernate/Linq/DefaultQueryProvider.cs b/src/NHibernate/Linq/DefaultQueryProvider.cs index 3442c3c7d7c..c8de5a37a5e 100644 --- a/src/NHibernate/Linq/DefaultQueryProvider.cs +++ b/src/NHibernate/Linq/DefaultQueryProvider.cs @@ -100,6 +100,22 @@ public TResult Execute(Expression expression) return (TResult)Execute(expression); } + //TODO 6.0: Add to INhQueryProvider interface + public virtual IList ExecuteList(Expression expression) + { + var linqExpression = PrepareQuery(expression, out var query); + var resultTransformer = linqExpression.ExpressionToHqlTranslationResults?.PostExecuteTransformer; + if (resultTransformer == null) + { + return query.List(); + } + + return new List + { + (TResult) resultTransformer.DynamicInvoke(query.List().AsQueryable()) + }; + } + public IQueryProvider WithOptions(Action setOptions) { if (setOptions == null) throw new ArgumentNullException(nameof(setOptions)); diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index 4d70b7e40e1..8b31c2988a0 100644 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -2389,8 +2389,11 @@ public static class LinqExtensionMethods async Task> InternalToListAsync() { - var result = await provider.ExecuteAsync>(source.Expression, cancellationToken).ConfigureAwait(false); - return result.ToList(); + //TODO 6.0: Replace with provider.ExecuteListAsync + var result = provider is DefaultQueryProvider nhQueryProvider + ? await nhQueryProvider.ExecuteListAsync(source.Expression, cancellationToken).ConfigureAwait(false) + : await provider.ExecuteAsync>(source.Expression, cancellationToken).ConfigureAwait(false); + return (result as List) ?? result.ToList(); } } diff --git a/src/NHibernate/Linq/NhQueryable.cs b/src/NHibernate/Linq/NhQueryable.cs index d9db57de377..949f851c923 100644 --- a/src/NHibernate/Linq/NhQueryable.cs +++ b/src/NHibernate/Linq/NhQueryable.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using NHibernate.Engine; @@ -17,7 +18,7 @@ interface IEntityNameProvider /// /// Provides the main entry point to a LINQ query. /// - public class NhQueryable : QueryableBase, IEntityNameProvider + public class NhQueryable : QueryableBase, IEntityNameProvider, IEnumerable { // This constructor is called by our users, create a new IQueryExecutor. public NhQueryable(ISessionImplementor session) @@ -57,5 +58,14 @@ public override string ToString() { return "NHibernate.Linq.NhQueryable`1[" + EntityName + "]"; } + + IEnumerator IEnumerable.GetEnumerator() + { + //TODO 6.0: Cast to INhQueryProvider + return + Provider is DefaultQueryProvider nhProvider + ? nhProvider.ExecuteList(Expression).GetEnumerator() + : base.GetEnumerator(); + } } } From 0b368c73de766614bfd381b894e7df4f3b3be195 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Tue, 31 Mar 2020 12:25:11 +0300 Subject: [PATCH 19/43] remove ToColumns with 3 arguments --- src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 5 ----- .../Loader/Criteria/CriteriaQueryTranslator.cs | 2 +- .../Collection/AbstractCollectionPersister.cs | 6 ------ .../Collection/CollectionPropertyMapping.cs | 8 +------- .../Persister/Collection/ElementPropertyMapping.cs | 8 +------- .../Persister/Entity/AbstractEntityPersister.cs | 5 ----- .../Persister/Entity/AbstractPropertyMapping.cs | 6 ------ .../Persister/Entity/BasicEntityPropertyMapping.cs | 14 -------------- .../Persister/Entity/IPropertyMapping.cs | 4 +--- 9 files changed, 4 insertions(+), 54 deletions(-) diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index f434b3b7f51..bfffb9be928 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,11 +150,6 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - return fromElementType.GetBasePropertyMapping().ToColumns(pathCriteria, GetPropertyPath(propertyName), getSQLAlias); - } - public string[] ToColumns(string alias, string propertyName) { return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 8ebad5641c4..682dc73e084 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -766,7 +766,7 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - columns = propertyMapping.ToColumns(pathCriteria, propertyName, GetSQLAlias); + columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 593482633e8..20701df9435 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1386,12 +1386,6 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public string[] ToColumns(string alias, string propertyName) { if ("index".Equals(propertyName)) diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index de56cb9cd81..e9e2f89dc51 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,12 +57,6 @@ public bool TryToType(string propertyName, out IType type) } } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public string[] ToColumns(string alias, string propertyName) { string[] cols; @@ -123,4 +117,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index ad412a19774..20e9899ddb6 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,12 +47,6 @@ public bool TryToType(string propertyName, out IType outType) } } - public string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public string[] ToColumns(string alias, string propertyName) { if (propertyName == null || "id".Equals(propertyName)) @@ -77,4 +71,4 @@ public IType Type #endregion } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 879a0d670f3..924da726cc1 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2050,11 +2050,6 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } - public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - return propertyMapping.ToColumns(pathCriteria, propertyName, getSQLAlias); - } - public virtual string[] ToColumns(string alias, string propertyName) { return propertyMapping.ToColumns(alias, propertyName); diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index 46e8ca70e34..c027568bf18 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,12 +44,6 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } - public virtual string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - string alias = getSQLAlias(pathCriteria); - return ToColumns(alias, propertyName); - } - public virtual string[] ToColumns(string alias, string propertyName) { //TODO: *two* hashmap lookups here is one too many... diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 6c7b31a6940..02f625bd550 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -1,7 +1,4 @@ -using System; -using NHibernate.Criterion; using NHibernate.Type; -using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Persister.Entity { @@ -29,17 +26,6 @@ public override IType Type get { return persister.Type; } } - public override string[] ToColumns(ICriteria pathCriteria, string propertyName, Func getSQLAlias) - { - var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; - if (withClause != null && withClause.PropertyName == propertyName) - { - return base.ToColumns(persister.GenerateTableAlias(getSQLAlias(pathCriteria), 0), propertyName); - } - - return base.ToColumns(pathCriteria, propertyName, getSQLAlias); - } - public override string[] ToColumns(string alias, string propertyName) { return diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index b348d36eae9..dbe08dd9139 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -29,8 +29,6 @@ public interface IPropertyMapping /// true if a type was found, false if not bool TryToType(string propertyName, out IType type); - string[] ToColumns(ICriteria pathCriteria, string propertyName, System.Func getSQLAlias); - /// /// Given a query alias and a property path, return the qualified column name /// @@ -42,4 +40,4 @@ public interface IPropertyMapping /// Given a property path, return the corresponding column name(s). string[] ToColumns(string propertyName); } -} +} \ No newline at end of file From 39e7f52e6f983fff9e4b22b7d92f29f2d4444fd8 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Tue, 31 Mar 2020 12:39:24 +0300 Subject: [PATCH 20/43] consider overriden properties when getting subclass property table number --- src/NHibernate/Persister/Entity/AbstractEntityPersister.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 924da726cc1..0bf2e57c673 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2110,7 +2110,7 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath) return getSubclassColumnTableNumberClosure()[idx]; } }*/ - int index = Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! + int index = Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); } From f8a3a22f2c323e6e6b9ce1d38533564c6c9093e9 Mon Sep 17 00:00:00 2001 From: hailtondecastro Date: Tue, 31 Mar 2020 19:43:59 -0300 Subject: [PATCH 21/43] Fix custom sql loader with composite id (#452) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Frédéric Delaporte <12201973+fredericdelaporte@users.noreply.github.com> Co-authored-by: maca88 --- src/AsyncGenerator.yml | 4 + .../Async/NHSpecificTest/NH3079/Fixture.cs | 246 ++++++++++++++++++ .../NHSpecificTest/NH3079/Employer.cs | 13 + .../NHSpecificTest/NH3079/EmployerCpId.cs | 22 ++ .../NHSpecificTest/NH3079/Employment.cs | 9 + .../NHSpecificTest/NH3079/EmploymentCpId.cs | 25 ++ .../NHSpecificTest/NH3079/Fixture.cs | 235 +++++++++++++++++ .../NHSpecificTest/NH3079/Mappings.hbm.xml | 163 ++++++++++++ .../NHSpecificTest/NH3079/Person.cs | 13 + .../NHSpecificTest/NH3079/PersonCpId.cs | 22 ++ .../NH3079/PersonNoComponent.cs | 24 ++ src/NHibernate/Async/Impl/SqlQueryImpl.cs | 2 + src/NHibernate/Impl/AbstractQueryImpl.cs | 23 +- src/NHibernate/Impl/SqlQueryImpl.cs | 64 +++++ 14 files changed, 856 insertions(+), 9 deletions(-) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/NH3079/Fixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/Employer.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/EmployerCpId.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/Employment.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/EmploymentCpId.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/Fixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/Mappings.hbm.xml create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/Person.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/PersonCpId.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/NH3079/PersonNoComponent.cs diff --git a/src/AsyncGenerator.yml b/src/AsyncGenerator.yml index 68ae0a1c735..99c765a110c 100644 --- a/src/AsyncGenerator.yml +++ b/src/AsyncGenerator.yml @@ -110,6 +110,10 @@ - conversion: Ignore name: GetEnumerator containingTypeName: IFutureEnumerable +# TODO 6.0: Consider if ComputeFlattenedParameters should remain ignored or not + - conversion: Ignore + name: ComputeFlattenedParameters + containingTypeName: SqlQueryImpl - conversion: ToAsync name: ExecuteReader containingTypeName: IBatcher diff --git a/src/NHibernate.Test/Async/NHSpecificTest/NH3079/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/NH3079/Fixture.cs new file mode 100644 index 00000000000..e9980e3a398 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/NH3079/Fixture.cs @@ -0,0 +1,246 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Collections.Generic; +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + using System.Threading.Tasks; + [TestFixture] + public class FixtureAsync : BugTestCase + { + // Disable second level cache + protected override string CacheConcurrencyStrategy => null; + + protected override void OnTearDown() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + s.CreateQuery("delete from Employment").ExecuteUpdate(); + + s.CreateQuery("delete from System.Object").ExecuteUpdate(); + + t.Commit(); + } + } + + protected override void OnSetUp() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var personList = new List(); + var employerList = new List(); + + // Global id to avoid false positive assertion with positional parameter + var gId = 1; + + for (var i = 0; i < 3; i++) + { + var personCpId = new PersonCpId { IdA = gId++, IdB = gId++ }; + var personObj = new Person + { CpId = personCpId, Name = "PERSON_" + personCpId.IdA + "_" + personCpId.IdB }; + s.Save(personObj); + personList.Add(personObj); + } + + for (var i = 0; i < 3; i++) + { + var employerCpId = new EmployerCpId { IdA = gId++, IdB = gId++ }; + var employerObj = new Employer + { CpId = employerCpId, Name = "EMPLOYER_" + employerCpId.IdA + "_" + employerCpId.IdB }; + s.Save(employerObj); + employerList.Add(employerObj); + } + + var employmentIds = new[] + { + gId++, + gId++, + gId++, + gId++, + }; + var employmentNames = new[] + { + //P1 + E1 + "EMPLOYMENT_" + employmentIds[0] + "_" + + personList[0].CpId.IdA + "_" + personList[0].CpId.IdA + "_" + + employerList[0].CpId.IdA + "_" + employerList[0].CpId.IdB, + //P1 + E2 + "EMPLOYMENT_" + employmentIds[1] + "_" + + personList[0].CpId.IdA + "_" + personList[0].CpId.IdA + "_" + + employerList[1].CpId.IdA + "_" + employerList[1].CpId.IdB, + //P2 + E2 + "EMPLOYMENT_" + employmentIds[2] + "_" + + personList[1].CpId.IdA + "_" + personList[1].CpId.IdA + "_" + + employerList[1].CpId.IdA + "_" + employerList[1].CpId.IdB, + //P2 + E3 + "EMPLOYMENT_" + employmentIds[2] + "_" + + personList[1].CpId.IdA + "_" + personList[1].CpId.IdA + "_" + + employerList[2].CpId.IdA + "_" + employerList[2].CpId.IdB + }; + var employmentPersons = new[] + { + personList[0], + personList[0], + personList[1], + personList[1] + }; + var employmentEmployers = new[] + { + employerList[0], + employerList[1], + employerList[1], + employerList[2] + }; + + for (var k = 0; k < employmentIds.Length; k++) + { + var employmentCpId = new EmploymentCpId + { + Id = employmentIds[k], + PersonObj = employmentPersons[k], + EmployerObj = employmentEmployers[k] + }; + var employmentObj = new Employment { CpId = employmentCpId, Name = employmentNames[k] }; + s.Save(employmentObj); + } + + for (var i = 0; i < 3; i++) + { + var personNoComponentObj = new PersonNoComponent { IdA = gId++, IdB = gId++ }; + personNoComponentObj.Name = "PERSON_NO_COMPONENT_" + personNoComponentObj.IdA + "_" + + personNoComponentObj.IdB; + s.Save(personNoComponentObj); + } + + t.Commit(); + } + } + + // Test reproducing the problem. + [Test] + public async Task GetPersonTestAsync() + { + using (var session = OpenSession()) + { + var person1_2 = await (session.GetAsync(new PersonCpId { IdA = 1, IdB = 2 })); + Assert.That(person1_2.Name, Is.EqualTo("PERSON_1_2")); + Assert.That( + person1_2.EmploymentList.Select(e => e.Name), + Is.EquivalentTo(new[] { "EMPLOYMENT_13_1_1_7_8", "EMPLOYMENT_14_1_1_9_10" })); + } + } + + // Test reproducing the problem. + [Test] + public async Task GetEmployerTestAsync() + { + using (var session = OpenSession()) + { + var employer7_8 = await (session.GetAsync(new EmployerCpId { IdA = 7, IdB = 8 })); + Assert.That(employer7_8.Name, Is.EqualTo("EMPLOYER_7_8")); + Assert.That( + employer7_8.EmploymentList.Select(e => e.Name), + Is.EquivalentTo(new[] { "EMPLOYMENT_13_1_1_7_8" })); + } + } + + [Test] + public async Task GetEmploymentTestAsync() + { + using (var session = OpenSession()) + { + var employment_13_1_2_7_8 = + await (session.GetAsync( + new EmploymentCpId + { + Id = 13, + PersonObj = + new Person + { + CpId = new PersonCpId { IdA = 1, IdB = 2 } + }, + EmployerObj = + new Employer + { + CpId = new EmployerCpId { IdA = 7, IdB = 8 } + } + })); + Assert.That(employment_13_1_2_7_8.Name, Is.EqualTo("EMPLOYMENT_13_1_1_7_8")); + } + } + + [Test] + public async Task HqlPersonPositionalAsync() + { + using (var session = OpenSession()) + { + var personList = + await (session + .GetNamedQuery("personPositional") + .SetParameter(0, new PersonCpId { IdA = 1, IdB = 2 }) + .SetParameter(1, new PersonCpId { IdA = 3, IdB = 4 }) + .ListAsync()); + Assert.That( + personList.Select(e => e.Name), + Is.EquivalentTo(new[] { "PERSON_1_2", "PERSON_3_4" })); + } + } + + [Test] + public async Task HqlPersonNamedAsync() + { + using (var session = OpenSession()) + { + var personList = + await (session + .GetNamedQuery("personNamed") + .SetParameter("id1", new PersonCpId { IdA = 1, IdB = 2 }) + .SetParameter("id2", new PersonCpId { IdA = 3, IdB = 4 }) + .ListAsync()); + Assert.That( + personList.Select(e => e.Name), + Is.EquivalentTo(new[] { "PERSON_1_2", "PERSON_3_4" })); + } + } + + [Test] + public async Task GetPersonNoComponentAsync() + { + using (var session = OpenSession()) + { + var person17_18 = + await (session.GetAsync(new PersonNoComponent { IdA = 17, IdB = 18 })); + Assert.That(person17_18.Name, Is.EqualTo("PERSON_NO_COMPONENT_17_18")); + } + } + + [Test] + public async Task SqlPersonNoComponentAsync() + { + using (var session = OpenSession()) + { + var personList = + await (session + .GetNamedQuery("personNoComponentSql") + .SetParameter(0, new PersonNoComponent { IdA = 17, IdB = 18 }) + .SetParameter(1, new PersonNoComponent { IdA = 19, IdB = 20 }) + .ListAsync()); + Assert.That( + personList.Select(e => e.Name), + Is.EquivalentTo(new[] { "PERSON_NO_COMPONENT_17_18", "PERSON_NO_COMPONENT_19_20" })); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/Employer.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/Employer.cs new file mode 100644 index 00000000000..c5fe0ffd9ec --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/Employer.cs @@ -0,0 +1,13 @@ +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class Employer + { + public virtual EmployerCpId CpId { get; set; } + + public virtual string Name { get; set; } + + public virtual ICollection EmploymentList { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/EmployerCpId.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/EmployerCpId.cs new file mode 100644 index 00000000000..dd8bd8cb69a --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/EmployerCpId.cs @@ -0,0 +1,22 @@ +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class EmployerCpId + { + public virtual int IdA { get; set; } + + public virtual int IdB { get; set; } + + public override bool Equals(object obj) + { + if (!(obj is EmployerCpId objCpId)) + return false; + + return IdA == objCpId.IdA && IdB == objCpId.IdB; + } + + public override int GetHashCode() + { + return IdA.GetHashCode() ^ IdB.GetHashCode(); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/Employment.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/Employment.cs new file mode 100644 index 00000000000..a84fd07cc6c --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/Employment.cs @@ -0,0 +1,9 @@ +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class Employment + { + public virtual EmploymentCpId CpId { get; set; } + + public virtual string Name { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/EmploymentCpId.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/EmploymentCpId.cs new file mode 100644 index 00000000000..2207044966a --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/EmploymentCpId.cs @@ -0,0 +1,25 @@ +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class EmploymentCpId + { + public virtual int Id { get; set; } + + public virtual Person PersonObj { get; set; } + + public virtual Employer EmployerObj { get; set; } + + public override bool Equals(object obj) + { + if (!(obj is EmploymentCpId objCpId)) + return false; + + return Id == objCpId.Id && PersonObj.CpId == objCpId.PersonObj.CpId && + EmployerObj.CpId == objCpId.EmployerObj.CpId; + } + + public override int GetHashCode() + { + return Id.GetHashCode() ^ PersonObj.CpId.GetHashCode() ^ EmployerObj.CpId.GetHashCode(); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/Fixture.cs new file mode 100644 index 00000000000..f5a298c6afd --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/Fixture.cs @@ -0,0 +1,235 @@ +using System.Collections.Generic; +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + [TestFixture] + public class Fixture : BugTestCase + { + // Disable second level cache + protected override string CacheConcurrencyStrategy => null; + + protected override void OnTearDown() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + s.CreateQuery("delete from Employment").ExecuteUpdate(); + + s.CreateQuery("delete from System.Object").ExecuteUpdate(); + + t.Commit(); + } + } + + protected override void OnSetUp() + { + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + var personList = new List(); + var employerList = new List(); + + // Global id to avoid false positive assertion with positional parameter + var gId = 1; + + for (var i = 0; i < 3; i++) + { + var personCpId = new PersonCpId { IdA = gId++, IdB = gId++ }; + var personObj = new Person + { CpId = personCpId, Name = "PERSON_" + personCpId.IdA + "_" + personCpId.IdB }; + s.Save(personObj); + personList.Add(personObj); + } + + for (var i = 0; i < 3; i++) + { + var employerCpId = new EmployerCpId { IdA = gId++, IdB = gId++ }; + var employerObj = new Employer + { CpId = employerCpId, Name = "EMPLOYER_" + employerCpId.IdA + "_" + employerCpId.IdB }; + s.Save(employerObj); + employerList.Add(employerObj); + } + + var employmentIds = new[] + { + gId++, + gId++, + gId++, + gId++, + }; + var employmentNames = new[] + { + //P1 + E1 + "EMPLOYMENT_" + employmentIds[0] + "_" + + personList[0].CpId.IdA + "_" + personList[0].CpId.IdA + "_" + + employerList[0].CpId.IdA + "_" + employerList[0].CpId.IdB, + //P1 + E2 + "EMPLOYMENT_" + employmentIds[1] + "_" + + personList[0].CpId.IdA + "_" + personList[0].CpId.IdA + "_" + + employerList[1].CpId.IdA + "_" + employerList[1].CpId.IdB, + //P2 + E2 + "EMPLOYMENT_" + employmentIds[2] + "_" + + personList[1].CpId.IdA + "_" + personList[1].CpId.IdA + "_" + + employerList[1].CpId.IdA + "_" + employerList[1].CpId.IdB, + //P2 + E3 + "EMPLOYMENT_" + employmentIds[2] + "_" + + personList[1].CpId.IdA + "_" + personList[1].CpId.IdA + "_" + + employerList[2].CpId.IdA + "_" + employerList[2].CpId.IdB + }; + var employmentPersons = new[] + { + personList[0], + personList[0], + personList[1], + personList[1] + }; + var employmentEmployers = new[] + { + employerList[0], + employerList[1], + employerList[1], + employerList[2] + }; + + for (var k = 0; k < employmentIds.Length; k++) + { + var employmentCpId = new EmploymentCpId + { + Id = employmentIds[k], + PersonObj = employmentPersons[k], + EmployerObj = employmentEmployers[k] + }; + var employmentObj = new Employment { CpId = employmentCpId, Name = employmentNames[k] }; + s.Save(employmentObj); + } + + for (var i = 0; i < 3; i++) + { + var personNoComponentObj = new PersonNoComponent { IdA = gId++, IdB = gId++ }; + personNoComponentObj.Name = "PERSON_NO_COMPONENT_" + personNoComponentObj.IdA + "_" + + personNoComponentObj.IdB; + s.Save(personNoComponentObj); + } + + t.Commit(); + } + } + + // Test reproducing the problem. + [Test] + public void GetPersonTest() + { + using (var session = OpenSession()) + { + var person1_2 = session.Get(new PersonCpId { IdA = 1, IdB = 2 }); + Assert.That(person1_2.Name, Is.EqualTo("PERSON_1_2")); + Assert.That( + person1_2.EmploymentList.Select(e => e.Name), + Is.EquivalentTo(new[] { "EMPLOYMENT_13_1_1_7_8", "EMPLOYMENT_14_1_1_9_10" })); + } + } + + // Test reproducing the problem. + [Test] + public void GetEmployerTest() + { + using (var session = OpenSession()) + { + var employer7_8 = session.Get(new EmployerCpId { IdA = 7, IdB = 8 }); + Assert.That(employer7_8.Name, Is.EqualTo("EMPLOYER_7_8")); + Assert.That( + employer7_8.EmploymentList.Select(e => e.Name), + Is.EquivalentTo(new[] { "EMPLOYMENT_13_1_1_7_8" })); + } + } + + [Test] + public void GetEmploymentTest() + { + using (var session = OpenSession()) + { + var employment_13_1_2_7_8 = + session.Get( + new EmploymentCpId + { + Id = 13, + PersonObj = + new Person + { + CpId = new PersonCpId { IdA = 1, IdB = 2 } + }, + EmployerObj = + new Employer + { + CpId = new EmployerCpId { IdA = 7, IdB = 8 } + } + }); + Assert.That(employment_13_1_2_7_8.Name, Is.EqualTo("EMPLOYMENT_13_1_1_7_8")); + } + } + + [Test] + public void HqlPersonPositional() + { + using (var session = OpenSession()) + { + var personList = + session + .GetNamedQuery("personPositional") + .SetParameter(0, new PersonCpId { IdA = 1, IdB = 2 }) + .SetParameter(1, new PersonCpId { IdA = 3, IdB = 4 }) + .List(); + Assert.That( + personList.Select(e => e.Name), + Is.EquivalentTo(new[] { "PERSON_1_2", "PERSON_3_4" })); + } + } + + [Test] + public void HqlPersonNamed() + { + using (var session = OpenSession()) + { + var personList = + session + .GetNamedQuery("personNamed") + .SetParameter("id1", new PersonCpId { IdA = 1, IdB = 2 }) + .SetParameter("id2", new PersonCpId { IdA = 3, IdB = 4 }) + .List(); + Assert.That( + personList.Select(e => e.Name), + Is.EquivalentTo(new[] { "PERSON_1_2", "PERSON_3_4" })); + } + } + + [Test] + public void GetPersonNoComponent() + { + using (var session = OpenSession()) + { + var person17_18 = + session.Get(new PersonNoComponent { IdA = 17, IdB = 18 }); + Assert.That(person17_18.Name, Is.EqualTo("PERSON_NO_COMPONENT_17_18")); + } + } + + [Test] + public void SqlPersonNoComponent() + { + using (var session = OpenSession()) + { + var personList = + session + .GetNamedQuery("personNoComponentSql") + .SetParameter(0, new PersonNoComponent { IdA = 17, IdB = 18 }) + .SetParameter(1, new PersonNoComponent { IdA = 19, IdB = 20 }) + .List(); + Assert.That( + personList.Select(e => e.Name), + Is.EquivalentTo(new[] { "PERSON_NO_COMPONENT_17_18", "PERSON_NO_COMPONENT_19_20" })); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/NH3079/Mappings.hbm.xml new file mode 100644 index 00000000000..a9b8a606aa3 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/Mappings.hbm.xml @@ -0,0 +1,163 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/Person.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/Person.cs new file mode 100644 index 00000000000..b058470b237 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/Person.cs @@ -0,0 +1,13 @@ +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class Person + { + public virtual PersonCpId CpId { get; set; } + + public virtual string Name { get; set; } + + public virtual ICollection EmploymentList { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/PersonCpId.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/PersonCpId.cs new file mode 100644 index 00000000000..51788014334 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/PersonCpId.cs @@ -0,0 +1,22 @@ +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class PersonCpId + { + public int IdA { get; set; } + + public int IdB { get; set; } + + public override bool Equals(object obj) + { + if (!(obj is PersonCpId objCpId)) + return false; + + return IdA == objCpId.IdA && IdB == objCpId.IdB; + } + + public override int GetHashCode() + { + return IdA.GetHashCode() ^ IdB.GetHashCode(); + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/NH3079/PersonNoComponent.cs b/src/NHibernate.Test/NHSpecificTest/NH3079/PersonNoComponent.cs new file mode 100644 index 00000000000..7948701375d --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/NH3079/PersonNoComponent.cs @@ -0,0 +1,24 @@ +namespace NHibernate.Test.NHSpecificTest.NH3079 +{ + public class PersonNoComponent + { + public virtual int IdA { get; set; } + + public virtual int IdB { get; set; } + + public virtual string Name { get; set; } + + public override bool Equals(object obj) + { + if (!(obj is PersonNoComponent objNoComponent)) + return false; + + return IdA == objNoComponent.IdA && IdB == objNoComponent.IdB; + } + + public override int GetHashCode() + { + return IdA.GetHashCode() ^ IdB.GetHashCode(); + } + } +} diff --git a/src/NHibernate/Async/Impl/SqlQueryImpl.cs b/src/NHibernate/Async/Impl/SqlQueryImpl.cs index ec46826430a..f4a0342ce37 100644 --- a/src/NHibernate/Async/Impl/SqlQueryImpl.cs +++ b/src/NHibernate/Async/Impl/SqlQueryImpl.cs @@ -11,6 +11,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.Engine.Query; using NHibernate.Engine.Query.Sql; @@ -98,6 +99,7 @@ public partial class SqlQueryImpl : AbstractQueryImpl, ISQLQuery, ISynchronizabl Before(); try { + ComputeFlattenedParameters(); return await (Session.ExecuteNativeUpdateAsync(GenerateQuerySpecification(namedParams), GetQueryParameters(namedParams), cancellationToken)).ConfigureAwait(false); } finally diff --git a/src/NHibernate/Impl/AbstractQueryImpl.cs b/src/NHibernate/Impl/AbstractQueryImpl.cs index cd078ca7dd3..9ff4c712b0d 100644 --- a/src/NHibernate/Impl/AbstractQueryImpl.cs +++ b/src/NHibernate/Impl/AbstractQueryImpl.cs @@ -79,11 +79,11 @@ protected internal virtual void VerifyParameters() } /// - /// Perform parameter validation. Used prior to executing the encapsulated query. + /// Perform parameters validation. Flatten them if needed. Used prior to executing the encapsulated query. /// /// - /// if true, the first ? will not be verified since - /// its needed for e.g. callable statements returning a out parameter + /// If true, the first positional parameter will not be verified since + /// its needed for e.g. callable statements returning an out parameter. /// protected internal virtual void VerifyParameters(bool reserveFirstParameter) { @@ -95,11 +95,15 @@ protected internal virtual void VerifyParameters(bool reserveFirstParameter) throw new QueryException("Not all named parameters have been set: " + CollectionPrinter.ToString(missingParams), QueryString); } - int positionalValueSpan = 0; - for (int i = 0; i < values.Count; i++) + var positionalValueSpan = 0; + // Values and Types may be overriden to yield refined parameters, check them + // instead of the fields. + var values = Values; + var types = Types; + for (var i = 0; i < values.Count; i++) { - object obj = types[i]; - if (values[i] == UNSET_PARAMETER || obj == UNSET_TYPE) + var type = types[i]; + if (values[i] == UNSET_PARAMETER || type == UNSET_TYPE) { if (reserveFirstParameter && i == 0) { @@ -813,12 +817,13 @@ protected IDictionary NamedParameterLists get { return namedParameterLists; } } - protected IList Values + // TODO 6.0: Change type to IList + protected virtual IList Values { get { return values; } } - protected IList Types + protected virtual IList Types { get { return types; } } diff --git a/src/NHibernate/Impl/SqlQueryImpl.cs b/src/NHibernate/Impl/SqlQueryImpl.cs index 079c390cde1..5629e2f714a 100644 --- a/src/NHibernate/Impl/SqlQueryImpl.cs +++ b/src/NHibernate/Impl/SqlQueryImpl.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.Engine.Query; using NHibernate.Engine.Query.Sql; @@ -29,6 +30,8 @@ public partial class SqlQueryImpl : AbstractQueryImpl, ISQLQuery, ISynchronizabl private readonly bool callable; private bool autoDiscoverTypes; private readonly HashSet addedQuerySpaces = new HashSet(); + private List _flattenedTypes; + private List _flattenedValues; /// Constructs a SQLQueryImpl given a sql query defined in the mappings. /// The representation of the defined sql-query. @@ -298,6 +301,66 @@ protected internal override void VerifyParameters() } } + /// + protected internal override void VerifyParameters(bool reserveFirstParameter) + { + ComputeFlattenedParameters(); + base.VerifyParameters(reserveFirstParameter); + } + + // Flattening parameters is required for custom SQL loaders when entities have composite ids. + // See NH-3079 (#1117) + private void ComputeFlattenedParameters() + { + _flattenedTypes = new List(base.Types.Count * 2); + _flattenedValues = new List(base.Types.Count * 2); + FlattenTypesAndValues(base.Types, base.Values); + + void FlattenTypesAndValues(IList types, IList values) + { + for (var i = 0; i < types.Count; i++) + { + var type = types[i]; + var value = values[i]; + if (type is EntityType entityType) + { + type = entityType.GetIdentifierType(session); + value = entityType.GetIdentifier(value, session); + } + + if (type is IAbstractComponentType componentType) + { + FlattenTypesAndValues( + componentType.Subtypes, + componentType.GetPropertyValues(value, session)); + } + else + { + _flattenedTypes.Add(type); + _flattenedValues.Add(value); + } + } + } + } + + protected override IList Values => _flattenedValues ?? + throw new InvalidOperationException("Flattened parameters have not been computed"); + + protected override IList Types => _flattenedTypes ?? + throw new InvalidOperationException("Flattened parameters have not been computed"); + + public override object[] ValueArray() + { + // TODO 6.0: Change to Values.ToArray() + return _flattenedValues?.ToArray() ?? + throw new InvalidOperationException("Flattened parameters have not been computed"); + } + + public override IType[] TypeArray() + { + return Types.ToArray(); + } + public override IQuery SetLockMode(string alias, LockMode lockMode) { throw new NotSupportedException("cannot set the lock mode for a native SQL query"); @@ -309,6 +372,7 @@ public override int ExecuteUpdate() Before(); try { + ComputeFlattenedParameters(); return Session.ExecuteNativeUpdate(GenerateQuerySpecification(namedParams), GetQueryParameters(namedParams)); } finally From aff08719fe1ae64df27906faea7461e1bcf1aa1b Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Wed, 1 Apr 2020 01:50:33 +0300 Subject: [PATCH 22/43] Force join for comparisons in hql when not null entity key can represent null entity (#2081) Co-authored-by: Alexander Zaytsev --- .../Async/Hql/EntityJoinHqlTest.cs | 153 ++++++++++++++++- src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 154 +++++++++++++++++- .../Hql/EntityJoinHqlTestEntities.cs | 26 ++- src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs | 13 +- .../Hql/Ast/ANTLR/Tree/FromElement.cs | 21 ++- 5 files changed, 350 insertions(+), 17 deletions(-) diff --git a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs index dcca7fc2924..1222dc245f0 100644 --- a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs @@ -153,6 +153,124 @@ public async Task EntityJoinFoSubqueryAsync() } } + [Test] + public async Task EntityJoinWithNullableOneToOneEntityComparisonInWithClausShouldAddJoinAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + await (session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "left join OneToOneEntity st with st = ex.OneToOne " + ).SetMaxResults(1) + .UniqueResultAsync()); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(2)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public async Task NullableOneToOneWhereEntityIsNotNullAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + await (session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "where ex.OneToOne is not null " + ).SetMaxResults(1) + .UniqueResultAsync()); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public async Task NullableOneToOneWhereIdIsNotNullAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + await (session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "where ex.OneToOne.Id is not null " + ).SetMaxResults(1) + .UniqueResultAsync()); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public async Task NullablePropRefWhereIdEntityNotNullShouldAddJoinAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + await (session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "where ex.PropRef is not null " + ).SetMaxResults(1) + .UniqueResultAsync()); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "PropRefEntity").Count, Is.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public async Task NullableOneToOneFetchQueryIsNotAffectedAsync() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + await (session + .CreateQuery( + "select ex " + + "from NullableOwner ex left join fetch ex.OneToOne o " + + "where o is null " + ).SetMaxResults(1) + .UniqueResultAsync()); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + } + } + + [Test] + public async Task NullableOneToOneFetchQueryIsNotAffected2Async() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + await (session + .CreateQuery( + "select ex " + + "from NullableOwner ex left join fetch ex.OneToOne o " + + "where o.Id is null " + ).SetMaxResults(1) + .UniqueResultAsync()); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + } + } + [Test] public async Task EntityJoinWithEntityComparisonInWithClausShouldNotAddJoinAsync() { @@ -329,7 +447,7 @@ protected override HbmMapping GetMappings() rc.Property(e => e.Name); }); - + mapper.Class( rc => { @@ -367,6 +485,38 @@ protected override HbmMapping GetMappings() rc.Property(e => e.Name); }); + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + }); + + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + rc.Property(e => e.PropertyRef); + }); + + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + rc.OneToOne(e => e.OneToOne, m => m.Constrained(false)); + rc.ManyToOne( + e => e.PropRef, + m => + { + m.PropertyRef(nameof(PropRefEntity.PropertyRef)); + m.ForeignKey("none"); + m.NotFound(NotFoundMode.Ignore); + }); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); } @@ -431,7 +581,6 @@ protected override void OnSetUp() Composite1Key2 = _entityWithCompositeId.Key.Id2, CustomEntityNameId = _entityWithCustomEntityName.Id }; - session.Save(_noAssociation); session.Flush(); diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index fcd64564947..f1b82cb1da6 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -145,6 +145,124 @@ public void EntityJoinFoSubquery() } } + [Test] + public void EntityJoinWithNullableOneToOneEntityComparisonInWithClausShouldAddJoin() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "left join OneToOneEntity st with st = ex.OneToOne " + ).SetMaxResults(1) + .UniqueResult(); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(2)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public void NullableOneToOneWhereEntityIsNotNull() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "where ex.OneToOne is not null " + ).SetMaxResults(1) + .UniqueResult(); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public void NullableOneToOneWhereIdIsNotNull() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "where ex.OneToOne.Id is not null " + ).SetMaxResults(1) + .UniqueResult(); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public void NullablePropRefWhereIdEntityNotNullShouldAddJoin() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + session + .CreateQuery( + "select ex " + + "from NullableOwner ex " + + "where ex.PropRef is not null " + ).SetMaxResults(1) + .UniqueResult(); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "PropRefEntity").Count, Is.EqualTo(1)); + Assert.That(sqlLog.Appender.GetEvents().Length, Is.EqualTo(1), "Only one SQL select is expected"); + } + } + + [Test] + public void NullableOneToOneFetchQueryIsNotAffected() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + session + .CreateQuery( + "select ex " + + "from NullableOwner ex left join fetch ex.OneToOne o " + + "where o is null " + ).SetMaxResults(1) + .UniqueResult(); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + } + } + + [Test] + public void NullableOneToOneFetchQueryIsNotAffected2() + { + using (var sqlLog = new SqlLogSpy()) + using (var session = OpenSession()) + { + var entity = + session + .CreateQuery( + "select ex " + + "from NullableOwner ex left join fetch ex.OneToOne o " + + "where o.Id is null " + ).SetMaxResults(1) + .UniqueResult(); + + Assert.That(Regex.Matches(sqlLog.GetWholeLog(), "OneToOneEntity").Count, Is.EqualTo(1)); + } + } + [Test] public void EntityJoinWithEntityComparisonInWithClausShouldNotAddJoin() { @@ -374,7 +492,7 @@ protected override HbmMapping GetMappings() rc.Property(e => e.Name); }); - + mapper.Class( rc => { @@ -412,6 +530,39 @@ protected override HbmMapping GetMappings() rc.Property(e => e.Name); }); + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + }); + + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + rc.Property(e => e.PropertyRef, m => m.Column("EntityPropertyRef")); + }); + + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + rc.OneToOne(e => e.OneToOne, m => m.Constrained(false)); + rc.ManyToOne( + e => e.PropRef, + m => + { + m.Column("OwnerPropertyRef"); + m.PropertyRef(nameof(PropRefEntity.PropertyRef)); + m.ForeignKey("none"); + m.NotFound(NotFoundMode.Ignore); + }); + }); + + Node.AddMapping(mapper); UserEntityVisit.AddMapping(mapper); @@ -480,7 +631,6 @@ protected override void OnSetUp() Composite1Key2 = _entityWithCompositeId.Key.Id2, CustomEntityNameId = _entityWithCustomEntityName.Id }; - session.Save(_noAssociation); session.Flush(); diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTestEntities.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTestEntities.cs index 2b897d81b6a..277b0ea706b 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTestEntities.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTestEntities.cs @@ -5,22 +5,40 @@ namespace NHibernate.Test.Hql.EntityJoinHqlTestEntities public class EntityComplex { public virtual Guid Id { get; set; } - public virtual int Version { get; set; } - public virtual string Name { get; set; } - public virtual string LazyProp { get; set; } - public virtual EntityComplex SameTypeChild { get; set; } public virtual EntityComplex SameTypeChild2 { get; set; } } + public class OneToOneEntity + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + } + + public class PropRefEntity + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual string PropertyRef { get; set; } + } + + public class NullableOwner + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual OneToOneEntity OneToOne { get; set; } + public virtual PropRefEntity PropRef { get; set; } + } + public class EntityWithCompositeId { public virtual CompositeKey Key { get; set; } public virtual string Name { get; set; } } + public class CompositeKey { public int Id1 { get; set; } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs index e1ed1bce5e9..bbe0a91b740 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/DotNode.cs @@ -386,6 +386,9 @@ private void DereferenceEntity(EntityType entityType, bool implicitJoin, string string property = _propertyName; bool joinIsNeeded; + //For nullable entity comparisons we always need to add join (like not constrained one-to-one or not-found ignore associations) + bool comparisonWithNullableEntity = false; + if ( IsDotNode( parent ) ) { // our parent is another dot node, meaning we are being further dereferenced. @@ -393,7 +396,7 @@ private void DereferenceEntity(EntityType entityType, bool implicitJoin, string // entity's PK (because 'our' table would know the FK). parentAsDotNode = ( DotNode ) parent; property = parentAsDotNode._propertyName; - joinIsNeeded = generateJoin && !IsReferenceToPrimaryKey( parentAsDotNode._propertyName, entityType ); + joinIsNeeded = generateJoin && (entityType.IsNullable || !IsReferenceToPrimaryKey( parentAsDotNode._propertyName, entityType )); } else if ( ! Walker.IsSelectStatement ) { @@ -406,12 +409,18 @@ private void DereferenceEntity(EntityType entityType, bool implicitJoin, string } else { - joinIsNeeded = generateJoin || ((Walker.IsInSelect && !Walker.IsInCase) || (Walker.IsInFrom && !Walker.IsComparativeExpressionClause)); + comparisonWithNullableEntity = (Walker.IsComparativeExpressionClause && entityType.IsNullable); + joinIsNeeded = generateJoin || (Walker.IsInSelect && !Walker.IsInCase) || (Walker.IsInFrom && !Walker.IsComparativeExpressionClause) + || comparisonWithNullableEntity; } if ( joinIsNeeded ) { DereferenceEntityJoin( classAlias, entityType, implicitJoin, parent ); + if (comparisonWithNullableEntity) + { + _columns = FromElement.GetIdentityColumns(); + } } else { diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/FromElement.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/FromElement.cs index b84a43b23e6..fa4796baebc 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/FromElement.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/FromElement.cs @@ -482,6 +482,19 @@ public IType GetPropertyType(string propertyName, string propertyPath) } public virtual string GetIdentityColumn() + { + var cols = GetIdentityColumns(); + string result = string.Join(", ", cols); + + if (cols.Length > 1 && Walker.IsComparativeExpressionClause) + { + return "(" + result + ")"; + } + + return result; + } + + internal string[] GetIdentityColumns() { CheckInitialized(); string table = TableAlias; @@ -513,14 +526,8 @@ public virtual string GetIdentityColumn() { cols = GetPropertyMapping(propertyName).ToColumns(propertyName); } - string result = string.Join(", ", cols); - if (cols.Length > 1 && Walker.IsComparativeExpressionClause) - { - return "(" + result + ")"; - } - - return result; + return cols; } public void HandlePropertyBeingDereferenced(IType propertySource, string propertyName) From 2b73276888ea02878830b673c18739d52009e827 Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 1 Apr 2020 21:25:30 +0200 Subject: [PATCH 23/43] Reduce cast usage for COUNT aggregate and add support for Mssql count_big (#2061) Co-authored-by: Alexander Zaytsev --- .../Async/Linq/ByMethod/CountTests.cs | 46 +++++++++++++++++++ .../Linq/ByMethod/CountTests.cs | 46 +++++++++++++++++++ src/NHibernate/Dialect/Dialect.cs | 1 + .../Function/ClassicAggregateFunction.cs | 15 ++++-- .../Dialect/Function/ISQLFunction.cs | 43 +++++++++++++++++ .../Dialect/Function/ISQLFunctionExtended.cs | 27 +++++++++++ src/NHibernate/Dialect/MsSql2000Dialect.cs | 13 ++++-- src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs | 6 +++ src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g | 4 +- .../Hql/Ast/ANTLR/Tree/AggregateNode.cs | 15 ++++++ .../Hql/Ast/ANTLR/Tree/CountNode.cs | 7 +-- src/NHibernate/Hql/Ast/HqlTreeBuilder.cs | 5 ++ src/NHibernate/Hql/Ast/HqlTreeNode.cs | 13 ++++++ .../Visitors/HqlGeneratorExpressionVisitor.cs | 34 ++++++++------ src/NHibernate/Util/ExpressionsHelper.cs | 23 ++++++++++ 15 files changed, 270 insertions(+), 28 deletions(-) create mode 100644 src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs index 8dfdc304aae..5581badf6c8 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/CountTests.cs @@ -11,6 +11,7 @@ using System; using System.Linq; using NHibernate.Cfg; +using NHibernate.Dialect; using NUnit.Framework; using NHibernate.Linq; @@ -110,5 +111,50 @@ into temp Assert.That(result.Count, Is.EqualTo(77)); } + + [Test] + public async Task CheckSqlFunctionNameLongCountAsync() + { + var name = Dialect is MsSql2000Dialect ? "count_big" : "count"; + using (var sqlLog = new SqlLogSpy()) + { + var result = await (db.Orders.LongCountAsync()); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain($"{name}(")); + } + } + + [Test] + public async Task CheckSqlFunctionNameForCountAsync() + { + using (var sqlLog = new SqlLogSpy()) + { + var result = await (db.Orders.CountAsync()); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain("count(")); + } + } + + [Test] + public async Task CheckMssqlCountCastAsync() + { + if (!(Dialect is MsSql2000Dialect)) + { + Assert.Ignore(); + } + + using (var sqlLog = new SqlLogSpy()) + { + var result = await (db.Orders.CountAsync()); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Not.Contain("cast(")); + } + } } } diff --git a/src/NHibernate.Test/Linq/ByMethod/CountTests.cs b/src/NHibernate.Test/Linq/ByMethod/CountTests.cs index 3bb93c7a083..7ef2d7dbbe6 100644 --- a/src/NHibernate.Test/Linq/ByMethod/CountTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/CountTests.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using NHibernate.Cfg; +using NHibernate.Dialect; using NUnit.Framework; namespace NHibernate.Test.Linq.ByMethod @@ -98,5 +99,50 @@ into temp Assert.That(result.Count, Is.EqualTo(77)); } + + [Test] + public void CheckSqlFunctionNameLongCount() + { + var name = Dialect is MsSql2000Dialect ? "count_big" : "count"; + using (var sqlLog = new SqlLogSpy()) + { + var result = db.Orders.LongCount(); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain($"{name}(")); + } + } + + [Test] + public void CheckSqlFunctionNameForCount() + { + using (var sqlLog = new SqlLogSpy()) + { + var result = db.Orders.Count(); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Contain("count(")); + } + } + + [Test] + public void CheckMssqlCountCast() + { + if (!(Dialect is MsSql2000Dialect)) + { + Assert.Ignore(); + } + + using (var sqlLog = new SqlLogSpy()) + { + var result = db.Orders.Count(); + Assert.That(result, Is.EqualTo(830)); + + var log = sqlLog.GetWholeLog(); + Assert.That(log, Does.Not.Contain("cast(")); + } + } } } diff --git a/src/NHibernate/Dialect/Dialect.cs b/src/NHibernate/Dialect/Dialect.cs index af1995dca06..53c30b55175 100644 --- a/src/NHibernate/Dialect/Dialect.cs +++ b/src/NHibernate/Dialect/Dialect.cs @@ -55,6 +55,7 @@ public abstract partial class Dialect static Dialect() { StandardAggregateFunctions["count"] = new CountQueryFunctionInfo(); + StandardAggregateFunctions["count_big"] = new CountQueryFunctionInfo(); StandardAggregateFunctions["avg"] = new AvgQueryFunctionInfo(); StandardAggregateFunctions["max"] = new ClassicAggregateFunction("max", false); StandardAggregateFunctions["min"] = new ClassicAggregateFunction("min", false); diff --git a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs index 299605adb03..e0b78f1e1e2 100644 --- a/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs +++ b/src/NHibernate/Dialect/Function/ClassicAggregateFunction.cs @@ -1,15 +1,15 @@ using System; using System.Collections; -using System.Text; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; -using NHibernate.Util; namespace NHibernate.Dialect.Function { [Serializable] - public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar + public class ClassicAggregateFunction : ISQLFunction, IFunctionGrammar, ISQLFunctionExtended { private IType returnType = null; private readonly string name; @@ -45,6 +45,15 @@ public virtual IType ReturnType(IType columnType, IMapping mapping) return returnType ?? columnType; } + /// + public virtual IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) + { + return ReturnType(argumentTypes.FirstOrDefault(), mapping); + } + + /// + public string FunctionName => name; + public bool HasArguments { get { return true; } diff --git a/src/NHibernate/Dialect/Function/ISQLFunction.cs b/src/NHibernate/Dialect/Function/ISQLFunction.cs index 4433e5dc63c..5302625a01e 100644 --- a/src/NHibernate/Dialect/Function/ISQLFunction.cs +++ b/src/NHibernate/Dialect/Function/ISQLFunction.cs @@ -1,4 +1,6 @@ using System.Collections; +using System.Collections.Generic; +using System.Linq; using NHibernate.Engine; using NHibernate.SqlCommand; using NHibernate.Type; @@ -41,4 +43,45 @@ public interface ISQLFunction /// SQL fragment for the function. SqlString Render(IList args, ISessionFactoryImplementor factory); } + + // 6.0 TODO: Remove + internal static class SQLFunctionExtensions + { + /// + /// Get the type that will be effectively returned by the underlying database. + /// + /// The sql function. + /// The types of arguments. + /// The mapping for retrieving the argument sql types. + /// Whether to throw when the number of arguments is invalid or they are not supported. + /// The type returned by the underlying database or when the number of arguments + /// is invalid or they are not supported. + /// When is set to and the + /// number of arguments is invalid or they are not supported. + public static IType GetEffectiveReturnType( + this ISQLFunction sqlFunction, + IEnumerable argumentTypes, + IMapping mapping, + bool throwOnError) + { + if (!(sqlFunction is ISQLFunctionExtended extendedSqlFunction)) + { + try + { + return sqlFunction.ReturnType(argumentTypes.FirstOrDefault(), mapping); + } + catch (QueryException) + { + if (throwOnError) + { + throw; + } + + return null; + } + } + + return extendedSqlFunction.GetEffectiveReturnType(argumentTypes, mapping, throwOnError); + } + } } diff --git a/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs new file mode 100644 index 00000000000..e2db4747198 --- /dev/null +++ b/src/NHibernate/Dialect/Function/ISQLFunctionExtended.cs @@ -0,0 +1,27 @@ +using System.Collections.Generic; +using NHibernate.Engine; +using NHibernate.Type; + +namespace NHibernate.Dialect.Function +{ + // 6.0 TODO: Merge into ISQLFunction + internal interface ISQLFunctionExtended : ISQLFunction + { + /// + /// The function name or when multiple functions/operators/statements are used. + /// + string FunctionName { get; } + + /// + /// Get the type that will be effectively returned by the underlying database. + /// + /// The types of arguments. + /// The mapping for retrieving the argument sql types. + /// Whether to throw when the number of arguments is invalid or they are not supported. + /// The type returned by the underlying database or when the number of arguments + /// is invalid or they are not supported. + /// When is set to and the + /// number of arguments is invalid or they are not supported. + IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError); + } +} diff --git a/src/NHibernate/Dialect/MsSql2000Dialect.cs b/src/NHibernate/Dialect/MsSql2000Dialect.cs index 7acb3cb9b4f..35c760e85b3 100644 --- a/src/NHibernate/Dialect/MsSql2000Dialect.cs +++ b/src/NHibernate/Dialect/MsSql2000Dialect.cs @@ -286,7 +286,8 @@ protected virtual void RegisterKeywords() protected virtual void RegisterFunctions() { - RegisterFunction("count", new CountBigQueryFunction()); + RegisterFunction("count", new CountQueryFunction()); + RegisterFunction("count_big", new CountBigQueryFunction()); RegisterFunction("abs", new StandardSQLFunction("abs")); RegisterFunction("absval", new StandardSQLFunction("absval")); @@ -704,11 +705,15 @@ protected virtual string GetSelectExistingObject(string catalog, string schema, [Serializable] protected class CountBigQueryFunction : ClassicAggregateFunction { - public CountBigQueryFunction() : base("count_big", true) { } + public CountBigQueryFunction() : base("count_big", true, NHibernateUtil.Int64) { } + } - public override IType ReturnType(IType columnType, IMapping mapping) + [Serializable] + private class CountQueryFunction : CountQueryFunctionInfo + { + public override IType GetEffectiveReturnType(IEnumerable argumentTypes, IMapping mapping, bool throwOnError) { - return NHibernateUtil.Int64; + return NHibernateUtil.Int32; } } diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs index 4f6f902891d..0f1c7a64603 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.cs @@ -310,6 +310,12 @@ private void EndFunctionTemplate(IASTNode m) } } + private void OutAggregateFunctionName(IASTNode m) + { + var aggregateNode = (AggregateNode) m; + Out(aggregateNode.FunctionName); + } + private void CommaBetweenParameters(String comma) { writer.CommaBetweenParameters(comma); diff --git a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g index 73d04c792ec..21cb3c7521e 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g +++ b/src/NHibernate/Hql/Ast/ANTLR/SqlGenerator.g @@ -150,7 +150,7 @@ selectExpr ; count - : ^(COUNT { Out("count("); } ( distinctOrAll ) ? countExpr { Out(")"); } ) + : ^(c=COUNT { OutAggregateFunctionName(c); Out("("); } ( distinctOrAll ) ? countExpr { Out(")"); } ) ; distinctOrAll @@ -344,7 +344,7 @@ caseExpr ; aggregate - : ^(a=AGGREGATE { Out(a); Out("("); } expr { Out(")"); } ) + : ^(a=AGGREGATE { OutAggregateFunctionName(a); Out("("); } expr { Out(")"); } ) ; diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs index d7dc2f9a0f0..a3887a56c01 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AggregateNode.cs @@ -1,5 +1,6 @@ using System; using Antlr.Runtime; +using NHibernate.Dialect.Function; using NHibernate.Type; using NHibernate.Hql.Ast.ANTLR.Util; @@ -19,6 +20,19 @@ public AggregateNode(IToken token) { } + public string FunctionName + { + get + { + if (SessionFactoryHelper.FindSQLFunction(Text) is ISQLFunctionExtended sqlFunction) + { + return sqlFunction.FunctionName; + } + + return Text; + } + } + public override IType DataType { get @@ -31,6 +45,7 @@ public override IType DataType base.DataType = value; } } + public override void SetScalarColumnText(int i) { ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs index 88ec9cc22fd..ef6d6272043 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/CountNode.cs @@ -1,4 +1,5 @@ using Antlr.Runtime; +using NHibernate.Dialect.Function; using NHibernate.Hql.Ast.ANTLR.Util; using NHibernate.Type; @@ -9,7 +10,7 @@ namespace NHibernate.Hql.Ast.ANTLR.Tree /// Author: josh /// Ported by: Steve Strong /// - class CountNode : AbstractSelectExpression, ISelectExpression + class CountNode : AggregateNode, ISelectExpression { public CountNode(IToken token) : base(token) { @@ -26,9 +27,5 @@ public override IType DataType base.DataType = value; } } - public override void SetScalarColumnText(int i) - { - ColumnHelper.GenerateSingleScalarColumn(ASTFactory, this, i); - } } } diff --git a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs index bb208295afc..dd59f2ff0f0 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeBuilder.cs @@ -307,6 +307,11 @@ public HqlCount Count(HqlExpression child) return new HqlCount(_factory, child); } + public HqlCountBig CountBig(HqlExpression child) + { + return new HqlCountBig(_factory, child); + } + public HqlRowStar RowStar() { return new HqlRowStar(_factory); diff --git a/src/NHibernate/Hql/Ast/HqlTreeNode.cs b/src/NHibernate/Hql/Ast/HqlTreeNode.cs index 160739d7f16..8967174bde3 100755 --- a/src/NHibernate/Hql/Ast/HqlTreeNode.cs +++ b/src/NHibernate/Hql/Ast/HqlTreeNode.cs @@ -697,6 +697,19 @@ public HqlCount(IASTFactory factory, HqlExpression child) } } + public class HqlCountBig : HqlExpression + { + public HqlCountBig(IASTFactory factory) + : base(HqlSqlWalker.COUNT, "count_big", factory) + { + } + + public HqlCountBig(IASTFactory factory, HqlExpression child) + : base(HqlSqlWalker.COUNT, "count_big", factory, child) + { + } + } + public class HqlAs : HqlExpression { public HqlAs(IASTFactory factory, HqlExpression expression, System.Type type) diff --git a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs index 224af655a24..cd9cd49eadb 100644 --- a/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs +++ b/src/NHibernate/Linq/Visitors/HqlGeneratorExpressionVisitor.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Linq.Expressions; using System.Runtime.CompilerServices; +using NHibernate.Dialect.Function; using NHibernate.Engine.Query; using NHibernate.Hql.Ast; using NHibernate.Hql.Ast.ANTLR; @@ -257,7 +258,22 @@ protected HqlTreeNode VisitNhAverage(NhAverageExpression expression) protected HqlTreeNode VisitNhCount(NhCountExpression expression) { - return _hqlTreeBuilder.Cast(_hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()), expression.Type); + string functionName; + HqlExpression countHqlExpression; + if (expression is NhLongCountExpression) + { + functionName = "count_big"; + countHqlExpression = _hqlTreeBuilder.CountBig(VisitExpression(expression.Expression).AsExpression()); + } + else + { + functionName = "count"; + countHqlExpression = _hqlTreeBuilder.Count(VisitExpression(expression.Expression).AsExpression()); + } + + return IsCastRequired(functionName, expression.Expression, expression.Type) + ? (HqlTreeNode) _hqlTreeBuilder.Cast(countHqlExpression, expression.Type) + : _hqlTreeBuilder.TransparentCast(countHqlExpression, expression.Type); } protected HqlTreeNode VisitNhMin(NhMinExpression expression) @@ -595,7 +611,7 @@ private bool IsCastRequired(Expression expression, System.Type toType, out bool { existType = false; return toType != typeof(object) && - IsCastRequired(GetType(expression), TypeFactory.GetDefaultTypeFor(toType), out existType); + IsCastRequired(ExpressionsHelper.GetType(_parameters, expression), TypeFactory.GetDefaultTypeFor(toType), out existType); } private bool IsCastRequired(IType type, IType toType, out bool existType) @@ -639,7 +655,7 @@ private bool IsCastRequired(IType type, IType toType, out bool existType) private bool IsCastRequired(string sqlFunctionName, Expression argumentExpression, System.Type returnType) { - var argumentType = GetType(argumentExpression); + var argumentType = ExpressionsHelper.GetType(_parameters, argumentExpression); if (argumentType == null || returnType == typeof(object)) { return false; @@ -657,18 +673,8 @@ private bool IsCastRequired(string sqlFunctionName, Expression argumentExpressio return true; // Fallback to the old behavior } - var fnReturnType = sqlFunction.ReturnType(argumentType, _parameters.SessionFactory); + var fnReturnType = sqlFunction.GetEffectiveReturnType(new[] {argumentType}, _parameters.SessionFactory, false); return fnReturnType == null || IsCastRequired(fnReturnType, returnNhType, out _); } - - private IType GetType(Expression expression) - { - // Try to get the mapped type for the member as it may be a non default one - return expression.Type == typeof(object) - ? null - : (ExpressionsHelper.TryGetMappedType(_parameters.SessionFactory, expression, out var type, out _, out _, out _) - ? type - : TypeFactory.GetDefaultTypeFor(expression.Type)); - } } } diff --git a/src/NHibernate/Util/ExpressionsHelper.cs b/src/NHibernate/Util/ExpressionsHelper.cs index 86fb8d04bd3..376a42ccda0 100644 --- a/src/NHibernate/Util/ExpressionsHelper.cs +++ b/src/NHibernate/Util/ExpressionsHelper.cs @@ -28,6 +28,29 @@ public static MemberInfo DecodeMemberAccessExpression(Expressi return ((MemberExpression)expression.Body).Member; } + /// + /// Get the mapped type for the given expression. + /// + /// The query parameters. + /// The expression. + /// The mapped type of the expression or when the mapped type was not + /// found and the type is . + internal static IType GetType(VisitorParameters parameters, Expression expression) + { + if (expression is ConstantExpression constantExpression && + parameters.ConstantToParameterMap.TryGetValue(constantExpression, out var param)) + { + return param.Type; + } + + if (TryGetMappedType(parameters.SessionFactory, expression, out var type, out _, out _, out _)) + { + return type; + } + + return expression.Type == typeof(object) ? null : TypeFactory.HeuristicType(expression.Type); + } + /// /// Try to get the mapped nullability from the given expression. /// From 7a95a0fc4d2baf653df85269c482f4172e19f641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericdelaporte@users.noreply.github.com> Date: Sat, 4 Apr 2020 20:14:31 +0200 Subject: [PATCH 24/43] Fix a missing async regeneration in #2081 --- src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs index 1222dc245f0..0632a01771a 100644 --- a/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Async/Hql/EntityJoinHqlTest.cs @@ -497,7 +497,7 @@ protected override HbmMapping GetMappings() { rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); rc.Property(e => e.Name); - rc.Property(e => e.PropertyRef); + rc.Property(e => e.PropertyRef, m => m.Column("EntityPropertyRef")); }); mapper.Class( @@ -510,6 +510,7 @@ protected override HbmMapping GetMappings() e => e.PropRef, m => { + m.Column("OwnerPropertyRef"); m.PropertyRef(nameof(PropRefEntity.PropertyRef)); m.ForeignKey("none"); m.NotFound(NotFoundMode.Ignore); From 97c238c591af53b33f727843d12d6c1cff815469 Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Sun, 5 Apr 2020 20:17:46 +0300 Subject: [PATCH 25/43] Use entities prepared by Loader in hql select projections (#2082) Co-authored-by: Alexander Zaytsev --- .../GH2064/OneToOneSelectProjectionFixture.cs | 118 ++++++++++++++++++ .../NHSpecificTest/GH2064/OneToOneEntity.cs | 10 ++ .../GH2064/OneToOneSelectProjectionFixture.cs | 106 ++++++++++++++++ .../NHSpecificTest/GH2064/ParentEntity.cs | 11 ++ .../Async/Loader/Hql/QueryLoader.cs | 4 +- .../Hql/Ast/ANTLR/Tree/ConstructorNode.cs | 7 -- .../Hql/Ast/ANTLR/Tree/SelectClause.cs | 31 +++-- src/NHibernate/Loader/Hql/QueryLoader.cs | 6 +- 8 files changed, 275 insertions(+), 18 deletions(-) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneEntity.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2064/ParentEntity.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs new file mode 100644 index 00000000000..07b0adabd83 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs @@ -0,0 +1,118 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH2064 +{ + using System.Threading.Tasks; + [TestFixture] + public class OneToOneSelectProjectionFixtureAsync : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.Assigned)); + rc.Property(e => e.Name); + }); + + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + rc.OneToOne(e => e.OneToOne, m => { }); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var nullableOwner = new ParentEntity() {Name = "Owner",}; + var oneToOne = new OneToOneEntity() {Name = "OneToOne"}; + nullableOwner.OneToOne = oneToOne; + session.Save(nullableOwner); + oneToOne.Id = nullableOwner.Id; + session.Save(oneToOne); + session.Flush(); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + // The HQL delete does all the job inside the database without loading the entities, but it does + // not handle delete order for avoiding violating constraints if any. Use + // session.Delete("from System.Object"); + // instead if in need of having NHibernate ordering the deletes, but this will cause + // loading the entities in the session. + session.CreateQuery("delete from System.Object").ExecuteUpdate(); + + transaction.Commit(); + } + } + + [Test] + public async Task QueryOneToOneAsync() + { + using (var session = OpenSession()) + { + var entity = + await (session + .Query() + .FirstOrDefaultAsync()); + Assert.That(entity.OneToOne, Is.Not.Null); + } + } + + [Test] + public async Task QueryOneToOneProjectionAsync() + { + using (var session = OpenSession()) + { + var entity = + await (session + .Query() + .Select( + x => new + { + x.Id, + SubType = new {x.OneToOne, x.Name}, + SubType2 = new + { + x.Id, + x.OneToOne, + SubType3 = new {x.Id, x.OneToOne} + }, + x.OneToOne + }).FirstOrDefaultAsync()); + Assert.That(entity.OneToOne, Is.Not.Null); + Assert.That(entity.SubType.OneToOne, Is.Not.Null); + Assert.That(entity.SubType2.OneToOne, Is.Not.Null); + Assert.That(entity.SubType2.SubType3.OneToOne, Is.Not.Null); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneEntity.cs b/src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneEntity.cs new file mode 100644 index 00000000000..40673d98926 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneEntity.cs @@ -0,0 +1,10 @@ +using System; + +namespace NHibernate.Test.NHSpecificTest.GH2064 +{ + public class OneToOneEntity + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs new file mode 100644 index 00000000000..7d6acdc74e5 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2064/OneToOneSelectProjectionFixture.cs @@ -0,0 +1,106 @@ +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2064 +{ + [TestFixture] + public class OneToOneSelectProjectionFixture : TestCaseMappingByCode + { + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.Assigned)); + rc.Property(e => e.Name); + }); + + mapper.Class( + rc => + { + rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); + rc.Property(e => e.Name); + rc.OneToOne(e => e.OneToOne, m => { }); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var nullableOwner = new ParentEntity() {Name = "Owner",}; + var oneToOne = new OneToOneEntity() {Name = "OneToOne"}; + nullableOwner.OneToOne = oneToOne; + session.Save(nullableOwner); + oneToOne.Id = nullableOwner.Id; + session.Save(oneToOne); + session.Flush(); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + // The HQL delete does all the job inside the database without loading the entities, but it does + // not handle delete order for avoiding violating constraints if any. Use + // session.Delete("from System.Object"); + // instead if in need of having NHibernate ordering the deletes, but this will cause + // loading the entities in the session. + session.CreateQuery("delete from System.Object").ExecuteUpdate(); + + transaction.Commit(); + } + } + + [Test] + public void QueryOneToOne() + { + using (var session = OpenSession()) + { + var entity = + session + .Query() + .FirstOrDefault(); + Assert.That(entity.OneToOne, Is.Not.Null); + } + } + + [Test] + public void QueryOneToOneProjection() + { + using (var session = OpenSession()) + { + var entity = + session + .Query() + .Select( + x => new + { + x.Id, + SubType = new {x.OneToOne, x.Name}, + SubType2 = new + { + x.Id, + x.OneToOne, + SubType3 = new {x.Id, x.OneToOne} + }, + x.OneToOne + }).FirstOrDefault(); + Assert.That(entity.OneToOne, Is.Not.Null); + Assert.That(entity.SubType.OneToOne, Is.Not.Null); + Assert.That(entity.SubType2.OneToOne, Is.Not.Null); + Assert.That(entity.SubType2.SubType3.OneToOne, Is.Not.Null); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2064/ParentEntity.cs b/src/NHibernate.Test/NHSpecificTest/GH2064/ParentEntity.cs new file mode 100644 index 00000000000..03325f531b0 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2064/ParentEntity.cs @@ -0,0 +1,11 @@ +using System; + +namespace NHibernate.Test.NHSpecificTest.GH2064 +{ + public class ParentEntity + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + public virtual OneToOneEntity OneToOne { get; set; } + } +} diff --git a/src/NHibernate/Async/Loader/Hql/QueryLoader.cs b/src/NHibernate/Async/Loader/Hql/QueryLoader.cs index 9135a5b8481..83832ddf38b 100644 --- a/src/NHibernate/Async/Loader/Hql/QueryLoader.cs +++ b/src/NHibernate/Async/Loader/Hql/QueryLoader.cs @@ -76,7 +76,9 @@ protected override async Task GetResultRowAsync(object[] row, DbDataRe resultRow = new object[queryCols]; for (int i = 0; i < queryCols; i++) { - resultRow[i] = await (ResultTypes[i].NullSafeGetAsync(rs, scalarColumns[i], session, null, cancellationToken)).ConfigureAwait(false); + resultRow[i] = _entityByResultTypeDic.TryGetValue(i, out var rowIndex) + ? row[rowIndex] + : await (ResultTypes[i].NullSafeGetAsync(rs, scalarColumns[i], session, null, cancellationToken)).ConfigureAwait(false); } } else diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ConstructorNode.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ConstructorNode.cs index 925c0452f88..6155b54e670 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ConstructorNode.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ConstructorNode.cs @@ -133,13 +133,6 @@ public void Prepare() private IType[] ResolveConstructorArgumentTypes() { ISelectExpression[] argumentExpressions = CollectSelectExpressions(); - - if ( argumentExpressions == null ) - { - // return an empty Type array - return Array.Empty(); - } - IType[] types = new IType[argumentExpressions.Length]; for ( int x = 0; x < argumentExpressions.Length; x++ ) { diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs index dfa5cf67f18..188fc67f0db 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs @@ -23,6 +23,8 @@ public class SelectClause : SelectExpressionList private IType[] _queryReturnTypes; private string[][] _columnNames; private readonly List _fromElementsForLoad = new List(); + private readonly Dictionary _entityByResultTypeDic = new Dictionary(); + private ConstructorNode _constructorNode; private string[] _aliases; private int[] _columnNamesStartPositions; @@ -134,19 +136,20 @@ public void InitializeExplicitSelectClause(FromClause fromClause) if (expr.IsConstructor) { _constructorNode = (ConstructorNode)expr; - IList constructorArgumentTypeList = _constructorNode.ConstructorArgumentTypeList; //sqlResultTypeList.addAll( constructorArgumentTypeList ); - queryReturnTypeList.AddRange(constructorArgumentTypeList); _scalarSelect = true; - for (int j = 1; j < _constructorNode.ChildCount; j++) + var ctorSelectExpressions = _constructorNode.CollectSelectExpressions(); + for (int j = 0; j < ctorSelectExpressions.Length; j++) { - ISelectExpression se = _constructorNode.GetChild(j) as ISelectExpression; + ISelectExpression se = ctorSelectExpressions[j]; - if (se != null && IsReturnableEntity(se)) + if (IsReturnableEntity(se)) { - _fromElementsForLoad.Add(se.FromElement); + AddEntityToProjection(queryReturnTypeList.Count, se); } + + queryReturnTypeList.Add(se.DataType); } } else @@ -163,10 +166,9 @@ public void InitializeExplicitSelectClause(FromClause fromClause) { _scalarSelect = true; } - - if (IsReturnableEntity(expr)) + else if (IsReturnableEntity(expr)) { - _fromElementsForLoad.Add(expr.FromElement); + AddEntityToProjection(queryReturnTypeList.Count, expr); } // Always add the type to the return type list. @@ -247,6 +249,12 @@ public void InitializeExplicitSelectClause(FromClause fromClause) FinishInitialization( /*sqlResultTypeList,*/ queryReturnTypeList); } + private void AddEntityToProjection(int resultIndex, ISelectExpression se) + { + _entityByResultTypeDic[resultIndex] = _fromElementsForLoad.Count; + _fromElementsForLoad.Add(se.FromElement); + } + private static FromElement GetOrigin(FromElement fromElement) { var realOrigin = fromElement.RealOrigin; @@ -271,6 +279,11 @@ public IList FromElementsForLoad get { return _fromElementsForLoad; } } + /// + /// Maps QueryReturnTypes[key] to entities from FromElementsForLoad[value] + /// + internal IReadOnlyDictionary EntityByResultTypeDic => _entityByResultTypeDic; + public bool IsScalarSelect { get { return _scalarSelect; } diff --git a/src/NHibernate/Loader/Hql/QueryLoader.cs b/src/NHibernate/Loader/Hql/QueryLoader.cs index ad95fd5362e..95ce94e6693 100644 --- a/src/NHibernate/Loader/Hql/QueryLoader.cs +++ b/src/NHibernate/Loader/Hql/QueryLoader.cs @@ -47,6 +47,7 @@ public partial class QueryLoader : BasicLoader private IType[] _cacheTypes; private ISet _uncacheableCollectionPersisters; private Dictionary[] _collectionUserProvidedAliases; + private IReadOnlyDictionary _entityByResultTypeDic; public QueryLoader(QueryTranslatorImpl queryTranslator, ISessionFactoryImplementor factory, SelectClause selectClause) : base(factory) @@ -214,6 +215,7 @@ protected override IDictionary GetCollectionUserProvidedAlias( private void Initialize(SelectClause selectClause) { IList fromElementList = selectClause.FromElementsForLoad; + _entityByResultTypeDic = selectClause.EntityByResultTypeDic; _hasScalars = selectClause.IsScalarSelect; _scalarColumnNames = selectClause.ColumnNames; @@ -401,7 +403,9 @@ protected override object[] GetResultRow(object[] row, DbDataReader rs, ISession resultRow = new object[queryCols]; for (int i = 0; i < queryCols; i++) { - resultRow[i] = ResultTypes[i].NullSafeGet(rs, scalarColumns[i], session, null); + resultRow[i] = _entityByResultTypeDic.TryGetValue(i, out var rowIndex) + ? row[rowIndex] + : ResultTypes[i].NullSafeGet(rs, scalarColumns[i], session, null); } } else From e206809e71a6121b008f4228f108e99038fe350d Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Mon, 6 Apr 2020 22:01:31 +0300 Subject: [PATCH 26/43] Add support for caching fetched relations with Criteria (#2090) --- .../Northwind/Entities/Northwind.cs | 4 +- .../Northwind/Entities/NorthwindQueryOver.cs | 102 ++++++ .../ReadonlyTests/QueryOverCacheableTests.cs | 290 ++++++++++++++++++ .../Async/Linq/QueryCacheableTests.cs | 90 ++++++ .../CriteriaNorthwindReadonlyTestCase.cs | 67 ++++ .../ReadonlyTests/QueryOverCacheableTests.cs | 279 +++++++++++++++++ .../ReadonlyTests/ReadonlyFixtureSetUp.cs | 12 + .../Linq/QueryCacheableTests.cs | 90 ++++++ src/NHibernate/Async/Loader/Loader.cs | 1 - .../Cache/QueryCacheResultBuilder.cs | 58 +--- .../Hql/Ast/ANTLR/Tree/SelectClause.cs | 2 - .../Loader/Criteria/CriteriaLoader.cs | 6 +- src/NHibernate/Loader/Custom/CustomLoader.cs | 2 +- src/NHibernate/Loader/Hql/QueryLoader.cs | 22 +- src/NHibernate/Loader/Loader.cs | 60 +++- src/NHibernate/Loader/OuterJoinLoader.cs | 2 +- .../Transform/CacheableResultTransformer.cs | 46 ++- src/NHibernate/Util/ArrayHelper.cs | 9 + src/NHibernate/Util/EnumerableExtensions.cs | 5 + 19 files changed, 1062 insertions(+), 85 deletions(-) create mode 100644 src/NHibernate.DomainModel/Northwind/Entities/NorthwindQueryOver.cs create mode 100644 src/NHibernate.Test/Async/Criteria/ReadonlyTests/QueryOverCacheableTests.cs create mode 100644 src/NHibernate.Test/Criteria/ReadonlyTests/CriteriaNorthwindReadonlyTestCase.cs create mode 100644 src/NHibernate.Test/Criteria/ReadonlyTests/QueryOverCacheableTests.cs create mode 100644 src/NHibernate.Test/Criteria/ReadonlyTests/ReadonlyFixtureSetUp.cs diff --git a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs index 5da2e35cadb..c4cbda23f26 100755 --- a/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs +++ b/src/NHibernate.DomainModel/Northwind/Entities/Northwind.cs @@ -94,9 +94,9 @@ public IQueryable Role get { return _session.Query(); } } - public IEnumerable IUsers + public IQueryable IUsers { get { return _session.Query(); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate.DomainModel/Northwind/Entities/NorthwindQueryOver.cs b/src/NHibernate.DomainModel/Northwind/Entities/NorthwindQueryOver.cs new file mode 100644 index 00000000000..0910e58ce99 --- /dev/null +++ b/src/NHibernate.DomainModel/Northwind/Entities/NorthwindQueryOver.cs @@ -0,0 +1,102 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.Linq; + +namespace NHibernate.DomainModel.Northwind.Entities +{ + public class NorthwindQueryOver + { + private readonly ISession _session; + + public NorthwindQueryOver(ISession session) + { + _session = session; + } + + public IQueryOver Customers + { + get { return _session.QueryOver(); } + } + + public IQueryOver Products + { + get { return _session.QueryOver(); } + } + + public IQueryOver Shippers + { + get { return _session.QueryOver(); } + } + + public IQueryOver Orders + { + get { return _session.QueryOver(); } + } + + public IQueryOver OrderLines + { + get { return _session.QueryOver(); } + } + + public IQueryOver Employees + { + get { return _session.QueryOver(); } + } + + public IQueryOver Categories + { + get { return _session.QueryOver(); } + } + + public IQueryOver Timesheets + { + get { return _session.QueryOver(); } + } + + public IQueryOver Animals + { + get { return _session.QueryOver(); } + } + + public IQueryOver Mammals + { + get { return _session.QueryOver(); } + } + + public IQueryOver Users + { + get { return _session.QueryOver(); } + } + + public IQueryOver PatientRecords + { + get { return _session.QueryOver(); } + } + + public IQueryOver States + { + get { return _session.QueryOver(); } + } + + public IQueryOver Patients + { + get { return _session.QueryOver(); } + } + + public IQueryOver Physicians + { + get { return _session.QueryOver(); } + } + + public IQueryOver Role + { + get { return _session.QueryOver(); } + } + + public IQueryOver IUsers + { + get { return _session.QueryOver(); } + } + } +} diff --git a/src/NHibernate.Test/Async/Criteria/ReadonlyTests/QueryOverCacheableTests.cs b/src/NHibernate.Test/Async/Criteria/ReadonlyTests/QueryOverCacheableTests.cs new file mode 100644 index 00000000000..b3f6e51b233 --- /dev/null +++ b/src/NHibernate.Test/Async/Criteria/ReadonlyTests/QueryOverCacheableTests.cs @@ -0,0 +1,290 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NHibernate.Cfg; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.SqlCommand; +using NUnit.Framework; + +namespace NHibernate.Test.Criteria.ReadonlyTests +{ + using System.Threading.Tasks; + [TestFixture] + public class QueryOverCacheableTestsAsync : CriteriaNorthwindReadonlyTestCase + { + //Just for discoverability + private class CriteriaCacheableTest{} + + protected override void Configure(Configuration config) + { + config.SetProperty(Environment.UseQueryCache, "true"); + config.SetProperty(Environment.GenerateStatistics, "true"); + base.Configure(config); + } + + [Test] + public async Task QueryIsCacheableAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + + await (db.Customers.Cacheable().Take(1).ListAsync()); + await (db.Customers.Cacheable().Take(1).ListAsync()); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public async Task QueryIsCacheable2Async() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + + await (db.Customers.Cacheable().Take(1).ListAsync()); + await (db.Customers.Take(1).ListAsync()); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(2), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(0), "Unexpected cache hit count"); + } + + [Test] + public async Task QueryIsCacheableWithRegionAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + await (Sfi.EvictQueriesAsync("test")); + await (Sfi.EvictQueriesAsync("other")); + + await (db.Customers.Cacheable().Take(1).CacheRegion("test").ListAsync()); + await (db.Customers.Cacheable().Take(1).CacheRegion("test").ListAsync()); + await (db.Customers.Cacheable().Take(1).CacheRegion("other").ListAsync()); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(2), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(2), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public async Task CanBeCombinedWithFetchAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + + await (db.Customers + .Cacheable() + .ListAsync()); + + await (db.Orders + .Cacheable() + .Take(1) + .ListAsync()); + + await (db.Customers + .Fetch(SelectMode.Fetch, x => x.Orders) + .Cacheable() + .Take(1) + .ListAsync()); + + await (db.Orders + .Fetch(SelectMode.Fetch, x => x.OrderLines) + .Cacheable() + .Take(1) + .ListAsync()); + + await (db.Customers + .Fetch(SelectMode.Fetch, x => x.Address) + .Where(x => x.CustomerId == "VINET") + .Cacheable() + .SingleOrDefaultAsync()); + + var customer = await (db.Customers + .Fetch(SelectMode.Fetch, x => x.Address) + .Where(x => x.CustomerId == "VINET") + .Cacheable() + .SingleOrDefaultAsync()); + + Assert.That(NHibernateUtil.IsInitialized(customer.Address), Is.True, "Expected the fetched Address to be initialized"); + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(5), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(5), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public async Task FetchIsCacheableAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + + var order = (await (db.Orders + .Fetch( + SelectMode.Fetch, + x => x.Customer, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .ListAsync())) + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + Session.Clear(); + + order = (await (db.Orders + .Fetch( + SelectMode.Fetch, + x => x.Customer, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .ListAsync())) + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public async Task FetchIsCacheableForJoinAliasAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + + Customer customer = null; + OrderLine orderLines = null; + Product product = null; + OrderLine prOrderLines = null; + + var order = (await (db.Orders + .JoinAlias(x => x.Customer, () => customer) + .JoinAlias(x => x.OrderLines, () => orderLines, JoinType.LeftOuterJoin) + .JoinAlias(() => orderLines.Product, () => product) + .JoinAlias(() => product.OrderLines, () => prOrderLines, JoinType.LeftOuterJoin) + .Where(x => x.OrderId == 10248) + .Cacheable() + .ListAsync())) + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + Session.Clear(); + + order = (await (db.Orders + .JoinAlias(x => x.Customer, () => customer) + .JoinAlias(x => x.OrderLines, () => orderLines, JoinType.LeftOuterJoin) + .JoinAlias(() => orderLines.Product, () => product) + .JoinAlias(() => product.OrderLines, () => prOrderLines, JoinType.LeftOuterJoin) + .Where(x => x.OrderId == 10248) + .Cacheable() + .ListAsync())) + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public async Task FutureFetchIsCacheableAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + var multiQueries = Sfi.ConnectionProvider.Driver.SupportsMultipleQueries; + + db.Orders + .Fetch(SelectMode.Fetch, x => x.Customer) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future(); + + var order = db.Orders + .Fetch( + SelectMode.Fetch, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future() + .ToList() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(multiQueries ? 1 : 2), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(2), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(2), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + Session.Clear(); + + db.Orders + .Fetch(SelectMode.Fetch, x => x.Customer) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future(); + + order = db.Orders + .Fetch( + SelectMode.Fetch, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future() + .ToList() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(2), "Unexpected cache hit count"); + } + + private static void AssertFetchedOrder(Order order) + { + Assert.That(order.Customer, Is.Not.Null, "Expected the fetched Customer to be not null"); + Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True, "Expected the fetched Customer to be initialized"); + Assert.That(NHibernateUtil.IsInitialized(order.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); + Assert.That(order.OrderLines, Has.Count.EqualTo(3), "Expected the fetched OrderLines to have 3 items"); + var orderLine = order.OrderLines.First(); + Assert.That(orderLine.Product, Is.Not.Null, "Expected the fetched Product to be not null"); + Assert.That(NHibernateUtil.IsInitialized(orderLine.Product), Is.True, "Expected the fetched Product to be initialized"); + Assert.That(NHibernateUtil.IsInitialized(orderLine.Product.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); + } + } +} diff --git a/src/NHibernate.Test/Async/Linq/QueryCacheableTests.cs b/src/NHibernate.Test/Async/Linq/QueryCacheableTests.cs index 4abcd24c0af..54a8a6cfc44 100644 --- a/src/NHibernate.Test/Async/Linq/QueryCacheableTests.cs +++ b/src/NHibernate.Test/Async/Linq/QueryCacheableTests.cs @@ -12,6 +12,7 @@ using NHibernate.Cfg; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Linq; +using NHibernate.Transform; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -406,13 +407,102 @@ public async Task FutureFetchIsCachableAsync() Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(2), "Unexpected cache hit count"); } + + [Explicit("Not working. dto.Customer retrieved from cache as uninitialized proxy")] + [Test] + public async Task ProjectedEntitiesAreCachableAsync() + { + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + var dto = await (session.Query() + .WithOptions(o => o.SetCacheable(true)) + .Where(x => x.OrderId == 10248) + .Select(x => new { x.Customer, Order = x }) + .FirstOrDefaultAsync()); + + Assert.That(dto, Is.Not.Null, "dto should not be null"); + Assert.That(dto.Order, Is.Not.Null, "dto.Order should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Order), Is.True, "dto.Order should be initialized"); + Assert.That(dto.Customer, Is.Not.Null, "dto.Customer should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Customer), Is.True, "dto.Customer from cache should be initialized"); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + session.Clear(); + + dto = await (session.Query() + .WithOptions(o => o.SetCacheable(true)) + .Where(x => x.OrderId == 10248) + .Select(x => new { x.Customer, Order = x }) + .FirstOrDefaultAsync()); + + Assert.That(dto, Is.Not.Null, "dto from cache should not be null"); + Assert.That(dto.Order, Is.Not.Null, "dto.Order from cache should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Order), Is.True, "dto.Order from cache should be initialized"); + Assert.That(dto.Customer, Is.Not.Null, "dto.Customer from cache should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Customer), Is.True, "dto.Customer from cache should be initialized"); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public async Task CacheHqlQueryWithFetchAndTransformerThatChangeTupleAsync() + { + if (!TestDialect.SupportsDuplicatedColumnAliases) + Assert.Ignore("Ignored due to GH-2092"); + + Sfi.Statistics.Clear(); + await (Sfi.EvictQueriesAsync()); + + // the combination of query and transformer doesn't make sense. + // It's simply used as example of returned data being transformed before caching leading to mismatch between + // Loader.ResultTypes collection and provided tuple + var order = await (session.CreateQuery("select o.Employee.FirstName, o from Order o join fetch o.Customer where o.OrderId = :id") + .SetInt32("id", 10248) + .SetCacheable(true) + .SetResultTransformer(Transformers.RootEntity) + .UniqueResultAsync()); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + Assert.That(order, Is.Not.Null); + Assert.That(order.Customer, Is.Not.Null); + Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True); + + session.Clear(); + Sfi.Statistics.Clear(); + + order = await (session.CreateQuery("select o.Employee.FirstName, o from Order o join fetch o.Customer where o.OrderId = :id") + .SetInt32("id", 10248) + .SetCacheable(true) + .SetResultTransformer(Transformers.RootEntity) + .UniqueResultAsync()); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + Assert.That(order, Is.Not.Null); + Assert.That(order.Customer, Is.Not.Null); + Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True); + } private static void AssertFetchedOrder(Order order) { + Assert.That(NHibernateUtil.IsInitialized(order), "Expected the order to be initialized"); + Assert.That(order.Customer, Is.Not.Null, "Expected the fetched Customer to be not null"); Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True, "Expected the fetched Customer to be initialized"); Assert.That(NHibernateUtil.IsInitialized(order.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); Assert.That(order.OrderLines, Has.Count.EqualTo(3), "Expected the fetched OrderLines to have 3 items"); var orderLine = order.OrderLines.First(); + Assert.That(orderLine.Product, Is.Not.Null, "Expected the fetched Product to be not null"); Assert.That(NHibernateUtil.IsInitialized(orderLine.Product), Is.True, "Expected the fetched Product to be initialized"); Assert.That(NHibernateUtil.IsInitialized(orderLine.Product.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); } diff --git a/src/NHibernate.Test/Criteria/ReadonlyTests/CriteriaNorthwindReadonlyTestCase.cs b/src/NHibernate.Test/Criteria/ReadonlyTests/CriteriaNorthwindReadonlyTestCase.cs new file mode 100644 index 00000000000..cd8cfd50fba --- /dev/null +++ b/src/NHibernate.Test/Criteria/ReadonlyTests/CriteriaNorthwindReadonlyTestCase.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NHibernate.DomainModel.Northwind.Entities; +using NUnit.Framework; + +namespace NHibernate.Test.Criteria.ReadonlyTests +{ + public abstract class CriteriaNorthwindReadonlyTestCase : NHibernate.Test.Linq.ReadonlyTestCase + { + private ISession _session = null; + protected NorthwindQueryOver db; + + protected override string[] Mappings + { + get + { + return new[] + { + "Northwind.Mappings.Customer.hbm.xml", + "Northwind.Mappings.Employee.hbm.xml", + "Northwind.Mappings.Order.hbm.xml", + "Northwind.Mappings.OrderLine.hbm.xml", + "Northwind.Mappings.Product.hbm.xml", + "Northwind.Mappings.ProductCategory.hbm.xml", + "Northwind.Mappings.Region.hbm.xml", + "Northwind.Mappings.Shipper.hbm.xml", + "Northwind.Mappings.Supplier.hbm.xml", + "Northwind.Mappings.Territory.hbm.xml", + "Northwind.Mappings.AnotherEntity.hbm.xml", + "Northwind.Mappings.Role.hbm.xml", + "Northwind.Mappings.User.hbm.xml", + "Northwind.Mappings.TimeSheet.hbm.xml", + "Northwind.Mappings.Animal.hbm.xml", + "Northwind.Mappings.Patient.hbm.xml" + }; + } + } + + public ISession Session + { + get { return _session; } + } + + protected override void OnSetUp() + { + _session = OpenSession(); + db = new NorthwindQueryOver(_session); + base.OnSetUp(); + } + + protected override void OnTearDown() + { + if (_session.IsOpen) + { + _session.Close(); + } + } + + public static void AssertByIds(IEnumerable entities, TId[] expectedIds, Converter entityIdGetter) + { + Assert.That(entities.Select(x => entityIdGetter(x)), Is.EquivalentTo(expectedIds)); + } + + protected IQueryOver Customers => Session.QueryOver(); + } +} diff --git a/src/NHibernate.Test/Criteria/ReadonlyTests/QueryOverCacheableTests.cs b/src/NHibernate.Test/Criteria/ReadonlyTests/QueryOverCacheableTests.cs new file mode 100644 index 00000000000..749a5184354 --- /dev/null +++ b/src/NHibernate.Test/Criteria/ReadonlyTests/QueryOverCacheableTests.cs @@ -0,0 +1,279 @@ +using System.Linq; +using NHibernate.Cfg; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.SqlCommand; +using NUnit.Framework; + +namespace NHibernate.Test.Criteria.ReadonlyTests +{ + [TestFixture] + public class QueryOverCacheableTests : CriteriaNorthwindReadonlyTestCase + { + //Just for discoverability + private class CriteriaCacheableTest{} + + protected override void Configure(Configuration config) + { + config.SetProperty(Environment.UseQueryCache, "true"); + config.SetProperty(Environment.GenerateStatistics, "true"); + base.Configure(config); + } + + [Test] + public void QueryIsCacheable() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + + db.Customers.Cacheable().Take(1).List(); + db.Customers.Cacheable().Take(1).List(); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public void QueryIsCacheable2() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + + db.Customers.Cacheable().Take(1).List(); + db.Customers.Take(1).List(); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(2), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(0), "Unexpected cache hit count"); + } + + [Test] + public void QueryIsCacheableWithRegion() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + Sfi.EvictQueries("test"); + Sfi.EvictQueries("other"); + + db.Customers.Cacheable().Take(1).CacheRegion("test").List(); + db.Customers.Cacheable().Take(1).CacheRegion("test").List(); + db.Customers.Cacheable().Take(1).CacheRegion("other").List(); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(2), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(2), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public void CanBeCombinedWithFetch() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + + db.Customers + .Cacheable() + .List(); + + db.Orders + .Cacheable() + .Take(1) + .List(); + + db.Customers + .Fetch(SelectMode.Fetch, x => x.Orders) + .Cacheable() + .Take(1) + .List(); + + db.Orders + .Fetch(SelectMode.Fetch, x => x.OrderLines) + .Cacheable() + .Take(1) + .List(); + + db.Customers + .Fetch(SelectMode.Fetch, x => x.Address) + .Where(x => x.CustomerId == "VINET") + .Cacheable() + .SingleOrDefault(); + + var customer = db.Customers + .Fetch(SelectMode.Fetch, x => x.Address) + .Where(x => x.CustomerId == "VINET") + .Cacheable() + .SingleOrDefault(); + + Assert.That(NHibernateUtil.IsInitialized(customer.Address), Is.True, "Expected the fetched Address to be initialized"); + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(5), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(5), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public void FetchIsCacheable() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + + var order = db.Orders + .Fetch( + SelectMode.Fetch, + x => x.Customer, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .List() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + Session.Clear(); + + order = db.Orders + .Fetch( + SelectMode.Fetch, + x => x.Customer, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .List() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public void FetchIsCacheableForJoinAlias() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + + Customer customer = null; + OrderLine orderLines = null; + Product product = null; + OrderLine prOrderLines = null; + + var order = db.Orders + .JoinAlias(x => x.Customer, () => customer) + .JoinAlias(x => x.OrderLines, () => orderLines, JoinType.LeftOuterJoin) + .JoinAlias(() => orderLines.Product, () => product) + .JoinAlias(() => product.OrderLines, () => prOrderLines, JoinType.LeftOuterJoin) + .Where(x => x.OrderId == 10248) + .Cacheable() + .List() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + Session.Clear(); + + order = db.Orders + .JoinAlias(x => x.Customer, () => customer) + .JoinAlias(x => x.OrderLines, () => orderLines, JoinType.LeftOuterJoin) + .JoinAlias(() => orderLines.Product, () => product) + .JoinAlias(() => product.OrderLines, () => prOrderLines, JoinType.LeftOuterJoin) + .Where(x => x.OrderId == 10248) + .Cacheable() + .List() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public void FutureFetchIsCacheable() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + var multiQueries = Sfi.ConnectionProvider.Driver.SupportsMultipleQueries; + + db.Orders + .Fetch(SelectMode.Fetch, x => x.Customer) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future(); + + var order = db.Orders + .Fetch( + SelectMode.Fetch, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future() + .ToList() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(multiQueries ? 1 : 2), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(2), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(2), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + Session.Clear(); + + db.Orders + .Fetch(SelectMode.Fetch, x => x.Customer) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future(); + + order = db.Orders + .Fetch( + SelectMode.Fetch, + x => x.OrderLines, + x => x.OrderLines.First().Product, + x => x.OrderLines.First().Product.OrderLines) + .Where(x => x.OrderId == 10248) + .Cacheable() + .Future() + .ToList() + .First(); + + AssertFetchedOrder(order); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(2), "Unexpected cache hit count"); + } + + private static void AssertFetchedOrder(Order order) + { + Assert.That(order.Customer, Is.Not.Null, "Expected the fetched Customer to be not null"); + Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True, "Expected the fetched Customer to be initialized"); + Assert.That(NHibernateUtil.IsInitialized(order.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); + Assert.That(order.OrderLines, Has.Count.EqualTo(3), "Expected the fetched OrderLines to have 3 items"); + var orderLine = order.OrderLines.First(); + Assert.That(orderLine.Product, Is.Not.Null, "Expected the fetched Product to be not null"); + Assert.That(NHibernateUtil.IsInitialized(orderLine.Product), Is.True, "Expected the fetched Product to be initialized"); + Assert.That(NHibernateUtil.IsInitialized(orderLine.Product.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); + } + } +} diff --git a/src/NHibernate.Test/Criteria/ReadonlyTests/ReadonlyFixtureSetUp.cs b/src/NHibernate.Test/Criteria/ReadonlyTests/ReadonlyFixtureSetUp.cs new file mode 100644 index 00000000000..b2ba33908d6 --- /dev/null +++ b/src/NHibernate.Test/Criteria/ReadonlyTests/ReadonlyFixtureSetUp.cs @@ -0,0 +1,12 @@ +using NUnit.Framework; + +namespace NHibernate.Test.Criteria.ReadonlyTests +{ + /// + /// Single one-time fixture set up for all test fixtures in NHibernate.Test.Criteria.ReadonlyTests namespace + /// + [SetUpFixture] + public class ReadonlyFixtureSetUp : NHibernate.Test.Linq.LinqReadonlyTestsContext + { + } +} diff --git a/src/NHibernate.Test/Linq/QueryCacheableTests.cs b/src/NHibernate.Test/Linq/QueryCacheableTests.cs index 79aedf77c5f..4475971f09e 100644 --- a/src/NHibernate.Test/Linq/QueryCacheableTests.cs +++ b/src/NHibernate.Test/Linq/QueryCacheableTests.cs @@ -2,6 +2,7 @@ using NHibernate.Cfg; using NHibernate.DomainModel.Northwind.Entities; using NHibernate.Linq; +using NHibernate.Transform; using NUnit.Framework; namespace NHibernate.Test.Linq @@ -395,13 +396,102 @@ public void FutureFetchIsCachable() Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(2), "Unexpected cache hit count"); } + + [Explicit("Not working. dto.Customer retrieved from cache as uninitialized proxy")] + [Test] + public void ProjectedEntitiesAreCachable() + { + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + var dto = session.Query() + .WithOptions(o => o.SetCacheable(true)) + .Where(x => x.OrderId == 10248) + .Select(x => new { x.Customer, Order = x }) + .FirstOrDefault(); + + Assert.That(dto, Is.Not.Null, "dto should not be null"); + Assert.That(dto.Order, Is.Not.Null, "dto.Order should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Order), Is.True, "dto.Order should be initialized"); + Assert.That(dto.Customer, Is.Not.Null, "dto.Customer should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Customer), Is.True, "dto.Customer from cache should be initialized"); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + + Sfi.Statistics.Clear(); + session.Clear(); + + dto = session.Query() + .WithOptions(o => o.SetCacheable(true)) + .Where(x => x.OrderId == 10248) + .Select(x => new { x.Customer, Order = x }) + .FirstOrDefault(); + + Assert.That(dto, Is.Not.Null, "dto from cache should not be null"); + Assert.That(dto.Order, Is.Not.Null, "dto.Order from cache should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Order), Is.True, "dto.Order from cache should be initialized"); + Assert.That(dto.Customer, Is.Not.Null, "dto.Customer from cache should not be null"); + Assert.That(NHibernateUtil.IsInitialized(dto.Customer), Is.True, "dto.Customer from cache should be initialized"); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + } + + [Test] + public void CacheHqlQueryWithFetchAndTransformerThatChangeTuple() + { + if (!TestDialect.SupportsDuplicatedColumnAliases) + Assert.Ignore("Ignored due to GH-2092"); + + Sfi.Statistics.Clear(); + Sfi.EvictQueries(); + + // the combination of query and transformer doesn't make sense. + // It's simply used as example of returned data being transformed before caching leading to mismatch between + // Loader.ResultTypes collection and provided tuple + var order = session.CreateQuery("select o.Employee.FirstName, o from Order o join fetch o.Customer where o.OrderId = :id") + .SetInt32("id", 10248) + .SetCacheable(true) + .SetResultTransformer(Transformers.RootEntity) + .UniqueResult(); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(1), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(1), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(1), "Unexpected cache miss count"); + Assert.That(order, Is.Not.Null); + Assert.That(order.Customer, Is.Not.Null); + Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True); + + session.Clear(); + Sfi.Statistics.Clear(); + + order = session.CreateQuery("select o.Employee.FirstName, o from Order o join fetch o.Customer where o.OrderId = :id") + .SetInt32("id", 10248) + .SetCacheable(true) + .SetResultTransformer(Transformers.RootEntity) + .UniqueResult(); + + Assert.That(Sfi.Statistics.QueryExecutionCount, Is.EqualTo(0), "Unexpected execution count"); + Assert.That(Sfi.Statistics.QueryCachePutCount, Is.EqualTo(0), "Unexpected cache put count"); + Assert.That(Sfi.Statistics.QueryCacheMissCount, Is.EqualTo(0), "Unexpected cache miss count"); + Assert.That(Sfi.Statistics.QueryCacheHitCount, Is.EqualTo(1), "Unexpected cache hit count"); + Assert.That(order, Is.Not.Null); + Assert.That(order.Customer, Is.Not.Null); + Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True); + } private static void AssertFetchedOrder(Order order) { + Assert.That(NHibernateUtil.IsInitialized(order), "Expected the order to be initialized"); + Assert.That(order.Customer, Is.Not.Null, "Expected the fetched Customer to be not null"); Assert.That(NHibernateUtil.IsInitialized(order.Customer), Is.True, "Expected the fetched Customer to be initialized"); Assert.That(NHibernateUtil.IsInitialized(order.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); Assert.That(order.OrderLines, Has.Count.EqualTo(3), "Expected the fetched OrderLines to have 3 items"); var orderLine = order.OrderLines.First(); + Assert.That(orderLine.Product, Is.Not.Null, "Expected the fetched Product to be not null"); Assert.That(NHibernateUtil.IsInitialized(orderLine.Product), Is.True, "Expected the fetched Product to be initialized"); Assert.That(NHibernateUtil.IsInitialized(orderLine.Product.OrderLines), Is.True, "Expected the fetched OrderLines to be initialized"); } diff --git a/src/NHibernate/Async/Loader/Loader.cs b/src/NHibernate/Async/Loader/Loader.cs index 33dff86b077..6367dbd9711 100644 --- a/src/NHibernate/Async/Loader/Loader.cs +++ b/src/NHibernate/Async/Loader/Loader.cs @@ -27,7 +27,6 @@ using NHibernate.Exceptions; using NHibernate.Hql.Util; using NHibernate.Impl; -using NHibernate.Intercept; using NHibernate.Param; using NHibernate.Persister.Collection; using NHibernate.Persister.Entity; diff --git a/src/NHibernate/Cache/QueryCacheResultBuilder.cs b/src/NHibernate/Cache/QueryCacheResultBuilder.cs index f3f145ddad3..fff6fe08d75 100644 --- a/src/NHibernate/Cache/QueryCacheResultBuilder.cs +++ b/src/NHibernate/Cache/QueryCacheResultBuilder.cs @@ -1,12 +1,7 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using NHibernate.Collection; -using NHibernate.Engine; -using NHibernate.Persister.Collection; using NHibernate.Type; namespace NHibernate.Cache @@ -17,56 +12,34 @@ namespace NHibernate.Cache public sealed class QueryCacheResultBuilder { private readonly IType[] _resultTypes; - private readonly IType[] _cacheTypes; - private readonly List _entityFetchIndexes = new List(); - private readonly List _collectionFetchIndexes = new List(); - private readonly bool _hasFetches; + private readonly Loader.Loader.QueryCacheInfo _cacheInfo; + public static bool IsCacheWithFetches(Loader.Loader loader) + { + return loader.CacheTypes.Length > loader.ResultTypes.Length; + } + internal QueryCacheResultBuilder(Loader.Loader loader) { _resultTypes = loader.ResultTypes; - _cacheTypes = loader.CacheTypes; - if (loader.EntityFetches != null) + if (IsCacheWithFetches(loader)) { - for (var i = 0; i < loader.EntityFetches.Length; i++) - { - if (loader.EntityFetches[i]) - { - _entityFetchIndexes.Add(i); - } - } - - _hasFetches = _entityFetchIndexes.Count > 0; - } - - if (loader.CollectionFetches == null) - { - return; + _cacheInfo = loader.CacheInfo; } - - for (var i = 0; i < loader.CollectionFetches.Length; i++) - { - if (loader.CollectionFetches[i]) - { - _collectionFetchIndexes.Add(i); - } - } - - _hasFetches = _hasFetches || _collectionFetchIndexes.Count > 0; } internal IList Result { get; } = new List(); internal void AddRow(object result, object[] entities, IPersistentCollection[] collections) { - if (!_hasFetches) + if (_cacheInfo == null) { Result.Add(result); return; } - var row = new object[_cacheTypes.Length]; + var row = new object[_cacheInfo.CacheTypes.Length]; if (_resultTypes.Length == 1) { row[0] = result; @@ -77,14 +50,17 @@ internal void AddRow(object result, object[] entities, IPersistentCollection[] c } var i = _resultTypes.Length; - foreach (var index in _entityFetchIndexes) + foreach (var index in _cacheInfo.AdditionalEntities) { row[i++] = entities[index]; } - foreach (var index in _collectionFetchIndexes) + if (collections != null) { - row[i++] = collections[index]; + foreach (var collection in collections) + { + row[i++] = collection; + } } Result.Add(row); @@ -92,7 +68,7 @@ internal void AddRow(object result, object[] entities, IPersistentCollection[] c internal IList GetResultList(IList cacheList) { - if (!_hasFetches) + if (_cacheInfo == null) { return cacheList; } diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs index 188fc67f0db..a3e4c9af4e1 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/SelectClause.cs @@ -56,7 +56,6 @@ public void InitializeDerivedSelectClause(FromClause fromClause) ASTAppender appender = new ASTAppender(ASTFactory, this); // Get ready to start adding nodes. int size = fromElements.Count; - List sqlResultTypeList = new List(size); List queryReturnTypeList = new List(size); int k = 0; @@ -78,7 +77,6 @@ public void InitializeDerivedSelectClause(FromClause fromClause) } _fromElementsForLoad.Add(fromElement); - sqlResultTypeList.Add(type); // Generate the select expression. string text = fromElement.RenderIdentifierSelect(size, k); diff --git a/src/NHibernate/Loader/Criteria/CriteriaLoader.cs b/src/NHibernate/Loader/Criteria/CriteriaLoader.cs index e3b816d2661..e3a5ff5dfda 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaLoader.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaLoader.cs @@ -75,7 +75,7 @@ public CriteriaLoader(IOuterJoinLoadable persister, ISessionFactoryImplementor f userAliases = walker.UserAliases; ResultTypes = walker.ResultTypes; includeInResultRow = walker.IncludeInResultRow; - resultRowLength = ArrayHelper.CountTrue(IncludeInResultRow); + resultRowLength = ArrayHelper.CountTrue(includeInResultRow); childFetchEntities = walker.ChildFetchEntities; EntityFetchLazyProperties = walker.EntityFetchLazyProperties; // fill caching objects only if there is a projection @@ -85,6 +85,10 @@ public CriteriaLoader(IOuterJoinLoadable persister, ISessionFactoryImplementor f } PostInstantiate(); + if (!translator.HasProjection) + { + CachePersistersWithCollections(ArrayHelper.IndexesOf(includeInResultRow, true)); + } } // Not ported: scroll (not supported) diff --git a/src/NHibernate/Loader/Custom/CustomLoader.cs b/src/NHibernate/Loader/Custom/CustomLoader.cs index 504ce824a19..dea693cf540 100644 --- a/src/NHibernate/Loader/Custom/CustomLoader.cs +++ b/src/NHibernate/Loader/Custom/CustomLoader.cs @@ -275,7 +275,7 @@ public override ILoadable[] EntityPersisters get { return entityPersisters; } } - protected override ICollectionPersister[] CollectionPersisters + protected internal override ICollectionPersister[] CollectionPersisters { get { return collectionPersisters; } } diff --git a/src/NHibernate/Loader/Hql/QueryLoader.cs b/src/NHibernate/Loader/Hql/QueryLoader.cs index 95ce94e6693..9b26eab0937 100644 --- a/src/NHibernate/Loader/Hql/QueryLoader.cs +++ b/src/NHibernate/Loader/Hql/QueryLoader.cs @@ -44,7 +44,6 @@ public partial class QueryLoader : BasicLoader private readonly NullableDictionary _sqlAliasByEntityAlias = new NullableDictionary(); private int _selectLength; private LockMode[] _defaultLockModes; - private IType[] _cacheTypes; private ISet _uncacheableCollectionPersisters; private Dictionary[] _collectionUserProvidedAliases; private IReadOnlyDictionary _entityByResultTypeDic; @@ -200,13 +199,11 @@ protected override string[] CollectionSuffixes get { return _collectionSuffixes; } } - protected override ICollectionPersister[] CollectionPersisters + protected internal override ICollectionPersister[] CollectionPersisters { get { return _collectionPersisters; } } - public override IType[] CacheTypes => _cacheTypes; - protected override IDictionary GetCollectionUserProvidedAlias(int index) { return _collectionUserProvidedAliases?[index]; @@ -232,7 +229,6 @@ private void Initialize(SelectClause selectClause) _collectionPersisters = new IQueryableCollection[length]; _collectionOwners = new int[length]; _collectionSuffixes = new string[length]; - CollectionFetches = new bool[length]; if (collectionFromElements.Any(qc => qc.QueryableCollection.IsManyToMany)) _collectionUserProvidedAliases = new Dictionary[length]; @@ -244,7 +240,6 @@ private void Initialize(SelectClause selectClause) // collectionSuffixes[i] = collectionFromElement.getColumnAliasSuffix(); // collectionSuffixes[i] = Integer.toString( i ) + "_"; _collectionSuffixes[i] = collectionFromElement.CollectionSuffix; - CollectionFetches[i] = collectionFromElement.IsFetch; } } @@ -258,8 +253,6 @@ private void Initialize(SelectClause selectClause) _includeInSelect = new bool[size]; _owners = new int[size]; _ownerAssociationTypes = new EntityType[size]; - EntityFetches = new bool[size]; - var cacheTypes = new List(ResultTypes); for (int i = 0; i < size; i++) { @@ -282,11 +275,6 @@ private void Initialize(SelectClause selectClause) _sqlAliasSuffixes[i] = (size == 1) ? "" : i + "_"; // sqlAliasSuffixes[i] = element.getColumnAliasSuffix(); _includeInSelect[i] = !element.IsFetch; - EntityFetches[i] = element.IsFetch; - if (element.IsFetch) - { - cacheTypes.Add(_entityPersisters[i].Type); - } if (_includeInSelect[i]) { _selectLength++; @@ -325,16 +313,10 @@ private void Initialize(SelectClause selectClause) } } - if (_collectionPersisters != null) - { - cacheTypes.AddRange(_collectionPersisters.Where((t, i) => CollectionFetches[i]).Select(t => t.CollectionType)); - } - - _cacheTypes = cacheTypes.ToArray(); - //NONE, because its the requested lock mode, not the actual! _defaultLockModes = ArrayHelper.Fill(LockMode.None, size); _uncacheableCollectionPersisters = _queryTranslator.UncacheableCollectionPersisters; + CachePersistersWithCollections(ArrayHelper.IndexesOf(_includeInSelect, true)); } public IList List(ISessionImplementor session, QueryParameters queryParameters) diff --git a/src/NHibernate/Loader/Loader.cs b/src/NHibernate/Loader/Loader.cs index 988d08cef98..49fd0974e38 100644 --- a/src/NHibernate/Loader/Loader.cs +++ b/src/NHibernate/Loader/Loader.cs @@ -17,7 +17,6 @@ using NHibernate.Exceptions; using NHibernate.Hql.Util; using NHibernate.Impl; -using NHibernate.Intercept; using NHibernate.Param; using NHibernate.Persister.Collection; using NHibernate.Persister.Entity; @@ -52,8 +51,21 @@ namespace NHibernate.Loader /// public abstract partial class Loader { - private static readonly INHibernateLogger Log = NHibernateLogger.For(typeof(Loader)); + /// + /// DTO for providing all query cache related details + /// + public sealed class QueryCacheInfo + { + public IType[] CacheTypes { get; set; } + + /// + /// Loader.EntityPersister indexes to be cached. + /// + public IReadOnlyList AdditionalEntities { get; set; } + } + private static readonly INHibernateLogger Log = NHibernateLogger.For(typeof(Loader)); + private Lazy _cacheInfo; private readonly ISessionFactoryImplementor _factory; private readonly SessionFactoryHelper _helper; private ColumnNameCache _columnNameCache; @@ -156,11 +168,18 @@ public virtual bool IsSubselectLoadingEnabled /// public IType[] ResultTypes { get; protected set; } - public bool[] EntityFetches { get; protected set; } + public IType[] CacheTypes => CacheInfo?.CacheTypes ?? ResultTypes; - public bool[] CollectionFetches { get; protected set; } + public virtual QueryCacheInfo CacheInfo => _cacheInfo?.Value; - public virtual IType[] CacheTypes => ResultTypes; + /// + /// Cache all additional persisters and collection persisters that were loaded by query (fetched entities and collections) + /// + /// Persister indexes that are cached as part of query result (so present in ResultTypes) + protected void CachePersistersWithCollections(IEnumerable resultTypePersisters) + { + _cacheInfo = new Lazy(() => GetQueryCacheInfo(resultTypePersisters)); + } public ISessionFactoryImplementor Factory { @@ -186,7 +205,7 @@ public ISessionFactoryImplementor Factory /// An (optional) persister for a collection to be initialized; only collection loaders /// return a non-null value /// - protected virtual ICollectionPersister[] CollectionPersisters + protected internal virtual ICollectionPersister[] CollectionPersisters { get { return null; } } @@ -1867,9 +1886,11 @@ internal QueryKey GenerateQueryKey(ISessionImplementor session, QueryParameters private CacheableResultTransformer CreateCacheableResultTransformer(QueryParameters queryParameters) { + bool skipTransformer = QueryCacheResultBuilder.IsCacheWithFetches(this); + return CacheableResultTransformer.Create( queryParameters.ResultTransformer, ResultRowAliases, IncludeInResultRow, - queryParameters.HasAutoDiscoverScalarTypes, SqlString); + queryParameters.HasAutoDiscoverScalarTypes, SqlString, skipTransformer); } private IList GetResultFromQueryCache( @@ -2080,6 +2101,31 @@ protected bool TryGetLimitString(Dialect.Dialect dialect, SqlString queryString, return false; } + private QueryCacheInfo GetQueryCacheInfo(IEnumerable resultTypePersisters) + { + var resultTypes = ResultTypes.EmptyIfNull(); + + var cacheTypes = new List(resultTypes.Count + EntityPersisters.Length + CollectionPersisters?.Length ?? 0); + cacheTypes.AddRange(resultTypes); + + int[] additionalEntities = null; + if (EntityPersisters.Length > 0) + { + additionalEntities = Enumerable.Range(0, EntityPersisters.Length).Except(resultTypePersisters).ToArray(); + cacheTypes.AddRange(additionalEntities.Select(i => EntityPersisters[i].EntityMetamodel.EntityType)); + } + + cacheTypes.AddRange(CollectionPersisters.EmptyIfNull().Select(p => p.CollectionType)); + + return cacheTypes.Count == resultTypes.Count + ? null + : new QueryCacheInfo + { + CacheTypes = cacheTypes.ToArray(), + AdditionalEntities = additionalEntities.EmptyIfNull(), + }; + } + #endregion } } diff --git a/src/NHibernate/Loader/OuterJoinLoader.cs b/src/NHibernate/Loader/OuterJoinLoader.cs index eb5d515b114..04874e5ef5e 100644 --- a/src/NHibernate/Loader/OuterJoinLoader.cs +++ b/src/NHibernate/Loader/OuterJoinLoader.cs @@ -94,7 +94,7 @@ protected override string[] Aliases get { return aliases; } } - protected override ICollectionPersister[] CollectionPersisters + protected internal override ICollectionPersister[] CollectionPersisters { get { return collectionPersisters; } } diff --git a/src/NHibernate/Transform/CacheableResultTransformer.cs b/src/NHibernate/Transform/CacheableResultTransformer.cs index 97d64f459ff..caed8a31221 100644 --- a/src/NHibernate/Transform/CacheableResultTransformer.cs +++ b/src/NHibernate/Transform/CacheableResultTransformer.cs @@ -24,6 +24,7 @@ public class CacheableResultTransformer : IResultTransformer public bool AutoDiscoverTypes { get; } private readonly SqlString _autoDiscoveredQuery; + private readonly bool _skipTransformer; private int _tupleLength; private int _tupleSubsetLength; @@ -60,7 +61,7 @@ public class CacheableResultTransformer : IResultTransformer /// a CacheableResultTransformer that is used to transform /// tuples to a value(s) that can be cached. // Since v5.1 - [Obsolete("Please use overload with autoDiscoverTypes parameter.")] + [Obsolete("Please use overload with skipTransformer parameter.")] public static CacheableResultTransformer Create(IResultTransformer transformer, string[] aliases, bool[] includeInTuple) @@ -68,6 +69,17 @@ public static CacheableResultTransformer Create(IResultTransformer transformer, return Create(transformer, aliases, includeInTuple, false, null); } + // Since 5.2 + [Obsolete("Please use overload with skipTransformer parameter.")] + public static CacheableResultTransformer Create( + IResultTransformer transformer, string[] aliases, bool[] includeInTuple, bool autoDiscoverTypes, + SqlString autoDiscoveredQuery) + { + return autoDiscoverTypes + ? Create(autoDiscoveredQuery) + : Create(includeInTuple, GetIncludeInTransform(transformer, aliases, includeInTuple), false); + } + /// /// Returns a CacheableResultTransformer that is used to transform /// tuples to a value(s) that can be cached. @@ -84,15 +96,25 @@ public static CacheableResultTransformer Create(IResultTransformer transformer, /// Indicates if types auto-discovery is enabled. /// If , the query for which they /// will be autodiscovered. + /// If true cache results untransformed. /// a CacheableResultTransformer that is used to transform /// tuples to a value(s) that can be cached. public static CacheableResultTransformer Create( - IResultTransformer transformer, string[] aliases, bool[] includeInTuple, bool autoDiscoverTypes, - SqlString autoDiscoveredQuery) + IResultTransformer transformer, + string[] aliases, + bool[] includeInTuple, + bool autoDiscoverTypes, + SqlString autoDiscoveredQuery, + bool skipTransformer) { return autoDiscoverTypes ? Create(autoDiscoveredQuery) - : Create(includeInTuple, GetIncludeInTransform(transformer, aliases, includeInTuple)); + : Create( + includeInTuple, + skipTransformer + ? null + : GetIncludeInTransform(transformer, aliases, includeInTuple), + skipTransformer); } /// @@ -106,11 +128,12 @@ public static CacheableResultTransformer Create( /// must be non-null /// Indexes that are included in the transformation. /// null if all elements in the tuple are included. + /// /// a CacheableResultTransformer that is used to transform /// tuples to a value(s) that can be cached. - private static CacheableResultTransformer Create(bool[] includeInTuple, bool[] includeInTransform) + private static CacheableResultTransformer Create(bool[] includeInTuple, bool[] includeInTransform, bool skipTransformer) { - return new CacheableResultTransformer(includeInTuple, includeInTransform); + return new CacheableResultTransformer(includeInTuple, includeInTransform, skipTransformer); } private static CacheableResultTransformer Create(SqlString autoDiscoveredQuery) @@ -136,8 +159,9 @@ private static bool[] GetIncludeInTransform(IResultTransformer transformer, stri return resultTransformer.IncludeInTransform(aliases, tupleLength); } - private CacheableResultTransformer(bool[] includeInTuple, bool[] includeInTransform) + private CacheableResultTransformer(bool[] includeInTuple, bool[] includeInTransform, bool skipTransformer) { + _skipTransformer = skipTransformer; InitializeTransformer(includeInTuple, includeInTransform); } @@ -212,7 +236,7 @@ public IList RetransformResults(IList transformedResults, if (_includeInTuple == null) throw new InvalidOperationException("This transformer is not initialized"); - if (!HasSameParameters(Create(transformer, aliases, includeInTuple, false, null))) + if (!HasSameParameters(Create(transformer, aliases, includeInTuple, false, null, _skipTransformer))) { throw new InvalidOperationException( "this CacheableResultTransformer is inconsistent with specified arguments; cannot re-transform" @@ -220,7 +244,11 @@ public IList RetransformResults(IList transformedResults, } bool requiresRetransform = true; string[] aliasesToUse = aliases == null ? null : Index(aliases); - if (transformer.Equals(_actualTransformer)) + + if (_skipTransformer) + { + } + else if (transformer.Equals(_actualTransformer)) { requiresRetransform = false; } diff --git a/src/NHibernate/Util/ArrayHelper.cs b/src/NHibernate/Util/ArrayHelper.cs index 52d24942bc9..f078cfa8607 100644 --- a/src/NHibernate/Util/ArrayHelper.cs +++ b/src/NHibernate/Util/ArrayHelper.cs @@ -180,6 +180,15 @@ public static int CountTrue(bool[] array) return array.Count(t => t); } + internal static IEnumerable IndexesOf(T[] array, T value) + { + for (int i = 0; i < array.Length; i++) + { + if (EqualityComparer.Default.Equals(array[i], value)) + yield return i; + } + } + public static bool ArrayEquals(T[] a, T[] b) { return ArrayComparer.Default.Equals(a, b); diff --git a/src/NHibernate/Util/EnumerableExtensions.cs b/src/NHibernate/Util/EnumerableExtensions.cs index b0e35892b06..0fb37b8ffe3 100644 --- a/src/NHibernate/Util/EnumerableExtensions.cs +++ b/src/NHibernate/Util/EnumerableExtensions.cs @@ -102,5 +102,10 @@ internal static IList ToIList(this IEnumerable list) { return list as IList ?? list.ToList(); } + + internal static IReadOnlyList EmptyIfNull(this IReadOnlyList list) + { + return list ?? Array.Empty(); + } } } From 7bc3fddd107bf0e7fcacb3894e3c1fc19fdf278c Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Thu, 9 Apr 2020 18:45:35 +0300 Subject: [PATCH 27/43] special handling of with clauses --- src/NHibernate.Test/App.config | 4 +-- .../Ast/ANTLR/Tree/AssignmentSpecification.cs | 2 +- .../Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 8 ++--- .../Hql/Ast/ANTLR/Tree/IntoClause.cs | 4 +-- .../Criteria/CriteriaQueryTranslator.cs | 11 ++++++- .../Collection/AbstractCollectionPersister.cs | 8 ++--- .../Collection/CollectionPropertyMapping.cs | 6 ++-- .../Collection/ElementPropertyMapping.cs | 6 ++-- .../Entity/AbstractEntityPersister.cs | 31 ++++++++++--------- .../Entity/AbstractPropertyMapping.cs | 4 +-- .../Entity/BasicEntityPropertyMapping.cs | 7 ++--- .../Persister/Entity/IPropertyMapping.cs | 7 +++-- src/NHibernate/Persister/Entity/IQueryable.cs | 3 +- .../Entity/JoinedSubclassEntityPersister.cs | 6 ++-- .../Entity/SingleTableEntityPersister.cs | 10 +++--- .../Entity/UnionSubclassEntityPersister.cs | 4 +-- 16 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/NHibernate.Test/App.config b/src/NHibernate.Test/App.config index d3965012af5..8d0aa714996 100644 --- a/src/NHibernate.Test/App.config +++ b/src/NHibernate.Test/App.config @@ -7,7 +7,7 @@ - + @@ -31,7 +31,7 @@ NHibernate.Dialect.MsSql2008Dialect NHibernate.Driver.Sql2008ClientDriver - Server=localhost\sqlexpress;Database=nhibernate;Integrated Security=SSPI + Server=localhost;Database=nhibernate;Integrated Security=SSPI NHibernate.Test.DebugConnectionProvider, NHibernate.Test ReadCommitted diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs index 351d29318fa..4f450eead3f 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs @@ -61,7 +61,7 @@ public AssignmentSpecification(IASTNode eq, IQueryable persister) } else { - temp.Add(persister.GetSubclassTableName(persister.GetSubclassPropertyTableNumber(propertyPath))); + temp.Add(persister.GetSubclassTableName(persister.GetSubclassPropertyTableNumber(propertyPath, false))); } _tableNames = new HashSet(temp); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index bfffb9be928..448bf8ceda2 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,14 +150,14 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { - return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); + return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName), useLastIndex); } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { - return fromElementType.GetBasePropertyMapping().ToColumns(GetPropertyPath(propertyName)); + return fromElementType.GetBasePropertyMapping().ToColumns(GetPropertyPath(propertyName), useLastIndex); } #endregion diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs index a59e6d5a8ec..05e4b9a4248 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs @@ -214,7 +214,7 @@ private bool IsSuperclassProperty(string propertyName) // // we may want to disallow it for discrim-subclass just for // consistency-sake (currently does not work anyway)... - return _persister.GetSubclassPropertyTableNumber(propertyName) != 0; + return _persister.GetSubclassPropertyTableNumber(propertyName, false) != 0; } /// @@ -263,4 +263,4 @@ private static bool AreSqlTypesCompatible(SqlType target, SqlType source) return target.Equals(source); } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 682dc73e084..6dafb3fda1d 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -13,6 +13,7 @@ using NHibernate.Type; using NHibernate.Util; using IQueryable = NHibernate.Persister.Entity.IQueryable; +using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Loader.Criteria { @@ -766,7 +767,15 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); + // here we can check if the condition belongs to a with clause + bool useLastIndex = false; + var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; + if (withClause != null && withClause.PropertyName == propertyName) + { + useLastIndex = true; + } + + columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName, useLastIndex); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 20701df9435..707bf423d38 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1386,7 +1386,7 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { if ("index".Equals(propertyName)) { @@ -1397,10 +1397,10 @@ public string[] ToColumns(string alias, string propertyName) return StringHelper.Qualify(alias, indexColumnNames); } - return elementPropertyMapping.ToColumns(alias, propertyName); + return elementPropertyMapping.ToColumns(alias, propertyName, useLastIndex); } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { if ("index".Equals(propertyName)) { @@ -1412,7 +1412,7 @@ public string[] ToColumns(string propertyName) return indexColumnNames; } - return elementPropertyMapping.ToColumns(propertyName); + return elementPropertyMapping.ToColumns(propertyName, useLastIndex); } protected abstract SqlCommandInfo GenerateDeleteString(); diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index e9e2f89dc51..569c53eb63d 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,7 +57,7 @@ public bool TryToType(string propertyName, out IType type) } } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { string[] cols; switch (propertyName) @@ -107,7 +107,7 @@ public string[] ToColumns(string alias, string propertyName) } } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { throw new System.NotSupportedException("References to collections must be define a SQL alias"); } @@ -117,4 +117,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index 20e9899ddb6..6ab04642a5f 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,7 +47,7 @@ public bool TryToType(string propertyName, out IType outType) } } - public string[] ToColumns(string alias, string propertyName) + public string[] ToColumns(string alias, string propertyName, bool useLastIndex) { if (propertyName == null || "id".Equals(propertyName)) { @@ -59,7 +59,7 @@ public string[] ToColumns(string alias, string propertyName) } } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex) { throw new System.NotSupportedException("References to collections must be define a SQL alias"); } @@ -71,4 +71,4 @@ public IType Type #endregion } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 0bf2e57c673..32639db4d0e 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -1118,9 +1118,9 @@ protected virtual bool IsIdOfTable(int property, int table) return false; } - protected abstract int GetSubclassPropertyTableNumber(int i); + protected abstract int GetSubclassPropertyTableNumber(int i, bool useLastIndex); - internal int GetSubclassPropertyTableNumber(string propertyName, string entityName) + internal int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) { var type = propertyMapping.ToType(propertyName); if (type.IsAssociationType && ((IAssociationType) type).UseLHSPrimaryKey) @@ -1271,7 +1271,7 @@ protected internal virtual SqlString GenerateLazySelectString() // use the subclass closure int propertyNumber = GetSubclassPropertyIndex(lazyPropertyNames[i]); - int tableNumber = GetSubclassPropertyTableNumber(propertyNumber); + int tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); tableNumbers.Add(tableNumber); int[] colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -1326,7 +1326,7 @@ protected virtual IDictionary GenerateLazySelectStringsByFetc // use the subclass closure var propertyNumber = GetSubclassPropertyIndex(lazyPropertyDescriptor.Name); - var tableNumber = GetSubclassPropertyTableNumber(propertyNumber); + var tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); tableNumbers.Add(tableNumber); var colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -2050,12 +2050,12 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } - public virtual string[] ToColumns(string alias, string propertyName) + public virtual string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) { - return propertyMapping.ToColumns(alias, propertyName); + return propertyMapping.ToColumns(alias, propertyName, useLastIndex); } - public string[] ToColumns(string propertyName) + public string[] ToColumns(string propertyName, bool useLastIndex = false) { return propertyMapping.GetColumnNames(propertyName); } @@ -2083,7 +2083,7 @@ public string[] GetPropertyColumnNames(string propertyName) /// SingleTableEntityPersister defines an overloaded form /// which takes the entity name. /// - public virtual int GetSubclassPropertyTableNumber(string propertyPath) + public virtual int GetSubclassPropertyTableNumber(string propertyPath, bool useLastIndex) { string rootPropertyName = StringHelper.Root(propertyPath); IType type = propertyMapping.ToType(rootPropertyName); @@ -2110,13 +2110,16 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath) return getSubclassColumnTableNumberClosure()[idx]; } }*/ - int index = Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! - return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); + int index = useLastIndex + ? Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName) + : Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! + + return index == -1 ? 0 : GetSubclassPropertyTableNumber(index, false); } public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath) { - int tableIndex = GetSubclassPropertyTableNumber(propertyPath); + int tableIndex = GetSubclassPropertyTableNumber(propertyPath, false); if (tableIndex == 0) { return Declarer.Class; @@ -2164,7 +2167,7 @@ private string GetSubclassAliasedColumn(string rootAlias, int tableNumber, strin public string[] ToColumns(string name, int i) { - string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i)); + string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i, false)); string[] cols = GetSubclassPropertyColumnNames(i); string[] templates = SubclassPropertyFormulaTemplateClosure[i]; string[] result = new string[cols.Length]; @@ -2398,7 +2401,7 @@ private EntityLoader GetAppropriateUniqueKeyLoader(string propertyName, IDiction return uniqueKeyLoaders[propertyName]; } - return CreateUniqueKeyLoader(propertyMapping.ToType(propertyName), propertyMapping.ToColumns(propertyName), enabledFilters); + return CreateUniqueKeyLoader(propertyMapping.ToType(propertyName), propertyMapping.ToColumns(propertyName, false), enabledFilters); } public int GetPropertyIndex(string propertyName) @@ -3682,7 +3685,7 @@ private IDictionary GetColumnsToTableAliasMap(string rootAlias) if (cols != null && cols.Length > 0) { - PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i)); + PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i, false)); propDictionary[key] = property; } } diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index c027568bf18..40f9550802e 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,7 +44,7 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } - public virtual string[] ToColumns(string alias, string propertyName) + public virtual string[] ToColumns(string alias, string propertyName, bool useLastIndex) { //TODO: *two* hashmap lookups here is one too many... string[] columns = GetColumns(propertyName); @@ -71,7 +71,7 @@ private string[] GetColumns(string propertyName) return columns; } - public virtual string[] ToColumns(string propertyName) + public virtual string[] ToColumns(string propertyName, bool useLastIndex) { string[] columns = GetColumns(propertyName); diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 02f625bd550..ff0e71aefc0 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -26,11 +26,10 @@ public override IType Type get { return persister.Type; } } - public override string[] ToColumns(string alias, string propertyName) + public override string[] ToColumns(string alias, string propertyName, bool useLastIndex) { - return - base.ToColumns(persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName)), - propertyName); + var tableAlias = persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName, useLastIndex)); + return base.ToColumns(tableAlias, propertyName, useLastIndex); } } } diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index dbe08dd9139..fc1dc5bf495 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -34,10 +34,11 @@ public interface IPropertyMapping /// /// /// + /// /// - string[] ToColumns(string alias, string propertyName); + string[] ToColumns(string alias, string propertyName, bool useLastIndex = false); /// Given a property path, return the corresponding column name(s). - string[] ToColumns(string propertyName); + string[] ToColumns(string propertyName, bool useLastIndex = false); } -} \ No newline at end of file +} diff --git a/src/NHibernate/Persister/Entity/IQueryable.cs b/src/NHibernate/Persister/Entity/IQueryable.cs index 2178b43a024..fdf69b1ddf8 100644 --- a/src/NHibernate/Persister/Entity/IQueryable.cs +++ b/src/NHibernate/Persister/Entity/IQueryable.cs @@ -112,13 +112,14 @@ public interface IQueryable : ILoadable, IPropertyMapping, IJoinable /// to which this property is mapped. /// /// The name of the property. + /// The name of the property. /// The number of the table to which the property is mapped. /// /// Note that this is not relative to the results from {@link #getConstraintOrderedTableNameClosure()}. /// It is relative to the subclass table name closure maintained internal to the persister (yick!). /// It is also relative to the indexing used to resolve {@link #getSubclassTableName}... /// - int GetSubclassPropertyTableNumber(string propertyPath); + int GetSubclassPropertyTableNumber(string propertyPath, bool useLastIndex); /// Determine whether the given property is declared by our /// mapped class, our super class, or one of our subclasses... diff --git a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs index 47b0a7c19a7..d11fb3d37c9 100644 --- a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs @@ -526,7 +526,7 @@ public override string GenerateFilterConditionAlias(string rootAlias) return GenerateTableAlias(rootAlias, tableSpan - 1); } - public override string[] ToColumns(string alias, string propertyName) + public override string[] ToColumns(string alias, string propertyName, bool useLastIndex) { if (EntityClass.Equals(propertyName)) { @@ -542,11 +542,11 @@ public override string[] ToColumns(string alias, string propertyName) } else { - return base.ToColumns(alias, propertyName); + return base.ToColumns(alias, propertyName, useLastIndex); } } - protected override int GetSubclassPropertyTableNumber(int i) + protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) { return subclassPropertyTableNumberClosure[i]; } diff --git a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs index 9aa8a71a3e4..4dbb004fc8e 100644 --- a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs @@ -675,7 +675,7 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i) + protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) { return subclassPropertyTableNumberClosure[i]; } @@ -696,12 +696,12 @@ protected override void AddDiscriminatorToInsert(SqlInsertBuilder insert) protected override bool IsSubclassPropertyDeferred(string propertyName, string entityName) { return - hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName)); + hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName, false)); } protected override bool IsPropertyDeferred(int propertyIndex) { - return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex)]; + return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex, false)]; } //Since v5.3 @@ -713,9 +713,9 @@ public override bool HasSequentialSelect //Since v5.3 [Obsolete("This method has no more usage in NHibernate and will be removed in a future version.")] - public new int GetSubclassPropertyTableNumber(string propertyName, string entityName) + public new int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) { - return base.GetSubclassPropertyTableNumber(propertyName, entityName); + return base.GetSubclassPropertyTableNumber(propertyName, entityName, useLastIndex); } //Since v5.3 diff --git a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs index c8d2c834cac..0b1dd021b61 100644 --- a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs @@ -289,12 +289,12 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i) + protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) { return 0; } - public override int GetSubclassPropertyTableNumber(string propertyName) + public override int GetSubclassPropertyTableNumber(string propertyName, bool useLastIndex) { return 0; } From cc0a57297996b518bde2d42d1aebd4da8cbbe613 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Fri, 10 Apr 2020 14:02:44 +0300 Subject: [PATCH 28/43] revert config changes --- src/NHibernate.Test/App.config | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/App.config b/src/NHibernate.Test/App.config index 8d0aa714996..d3965012af5 100644 --- a/src/NHibernate.Test/App.config +++ b/src/NHibernate.Test/App.config @@ -7,7 +7,7 @@ - + @@ -31,7 +31,7 @@ NHibernate.Dialect.MsSql2008Dialect NHibernate.Driver.Sql2008ClientDriver - Server=localhost;Database=nhibernate;Integrated Security=SSPI + Server=localhost\sqlexpress;Database=nhibernate;Integrated Security=SSPI NHibernate.Test.DebugConnectionProvider, NHibernate.Test ReadCommitted From 3ba3b619561f6a2e338843a71994e274b407b509 Mon Sep 17 00:00:00 2001 From: maca88 Date: Sun, 12 Apr 2020 16:56:12 +0200 Subject: [PATCH 29/43] Improve async locking (#2147) --- src/AsyncGenerator.yml | 3 - .../AsyncReaderWriterLockFixture.cs | 215 ++++++++ src/NHibernate.Test/NHibernate.Test.csproj | 5 + .../AsyncReaderWriterLockFixture.cs | 475 ++++++++++++++++++ src/NHibernate/Async/Cache/ReadWriteCache.cs | 21 +- .../Async/Cache/UpdateTimestampsCache.cs | 66 ++- .../Async/Id/Enhanced/OptimizerFactory.cs | 15 +- .../Async/Id/Enhanced/TableGenerator.cs | 4 +- src/NHibernate/Async/Id/IncrementGenerator.cs | 6 +- .../Async/Id/SequenceHiLoGenerator.cs | 4 +- src/NHibernate/Async/Id/TableGenerator.cs | 4 +- src/NHibernate/Async/Id/TableHiLoGenerator.cs | 4 +- src/NHibernate/Cache/ReadWriteCache.cs | 20 +- src/NHibernate/Cache/UpdateTimestampsCache.cs | 74 +-- .../Id/Enhanced/OptimizerFactory.cs | 98 ++-- src/NHibernate/Id/Enhanced/TableGenerator.cs | 7 +- src/NHibernate/Id/IncrementGenerator.cs | 13 +- src/NHibernate/Id/SequenceHiLoGenerator.cs | 35 +- src/NHibernate/Id/TableGenerator.cs | 13 +- src/NHibernate/Id/TableHiLoGenerator.cs | 35 +- src/NHibernate/Util/AsyncLock.cs | 18 +- src/NHibernate/Util/AsyncReaderWriterLock.cs | 252 ++++++++++ 22 files changed, 1185 insertions(+), 202 deletions(-) create mode 100644 src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs create mode 100644 src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs create mode 100644 src/NHibernate/Util/AsyncReaderWriterLock.cs diff --git a/src/AsyncGenerator.yml b/src/AsyncGenerator.yml index 99c765a110c..5c7754819fb 100644 --- a/src/AsyncGenerator.yml +++ b/src/AsyncGenerator.yml @@ -160,9 +160,6 @@ transformation: configureAwaitArgument: false localFunctions: true - asyncLock: - type: NHibernate.Util.AsyncLock - methodName: LockAsync documentationComments: addOrReplaceMethodSummary: - name: Commit diff --git a/src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs b/src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs new file mode 100644 index 00000000000..b22f20a3cd0 --- /dev/null +++ b/src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs @@ -0,0 +1,215 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.UtilityTest +{ + public partial class AsyncReaderWriterLockFixture + { + + [Test, Explicit] + public async Task TestConcurrentReadWriteAsync() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var writeReleaser = await (l.WriteLockAsync()); + Assert.That(l.Writing, Is.True); + + var secondWriteSemaphore = new SemaphoreSlim(0); + var secondWriteReleaser = default(AsyncReaderWriterLock.Releaser); + var secondWriteThread = new Thread( + () => + { + secondWriteSemaphore.Wait(); + secondWriteReleaser = l.WriteLock(); + }); + secondWriteThread.Priority = ThreadPriority.Highest; + secondWriteThread.Start(); + await (AssertEqualValueAsync(() => secondWriteThread.ThreadState == ThreadState.WaitSleepJoin, true)); + + var secondReadThreads = new Thread[20]; + var secondReadReleasers = new AsyncReaderWriterLock.Releaser[secondReadThreads.Length]; + var secondReadSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < secondReadReleasers.Length; j++) + { + var index = j; + var thread = new Thread( + () => + { + secondReadSemaphore.Wait(); + secondReadReleasers[index] = l.ReadLock(); + }); + thread.Priority = ThreadPriority.Highest; + secondReadThreads[j] = thread; + thread.Start(); + } + + await (AssertEqualValueAsync(() => secondReadThreads.All(o => o.ThreadState == ThreadState.WaitSleepJoin), true)); + + var firstReadReleaserTasks = new Task[30]; + var firstReadStopSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < firstReadReleaserTasks.Length; j++) + { + firstReadReleaserTasks[j] = Task.Run(async () => + { + var releaser = await (l.ReadLockAsync()); + await (firstReadStopSemaphore.WaitAsync()); + releaser.Dispose(); + }); + } + + await (AssertEqualValueAsync(() => l.ReadersWaiting, firstReadReleaserTasks.Length, waitDelay: 60000)); + + writeReleaser.Dispose(); + secondWriteSemaphore.Release(); + secondReadSemaphore.Release(secondReadThreads.Length); + await (Task.Delay(1000)); + firstReadStopSemaphore.Release(firstReadReleaserTasks.Length); + + await (AssertEqualValueAsync(() => firstReadReleaserTasks.All(o => o.IsCompleted), true)); + Assert.That(l.ReadersWaiting, Is.EqualTo(secondReadThreads.Length)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + await (AssertEqualValueAsync(() => secondWriteThread.IsAlive, false)); + await (AssertEqualValueAsync(() => secondReadThreads.All(o => o.IsAlive), true)); + + secondWriteReleaser.Dispose(); + await (AssertEqualValueAsync(() => secondReadThreads.All(o => !o.IsAlive), true)); + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(secondReadThreads.Length)); + + foreach (var secondReadReleaser in secondReadReleasers) + { + secondReadReleaser.Dispose(); + } + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + } + } + + [Test] + public async Task TestInvaildExitReadLockUsageAsync() + { + var l = new AsyncReaderWriterLock(); + var readReleaser = await (l.ReadLockAsync()); + var readReleaser2 = await (l.ReadLockAsync()); + + readReleaser.Dispose(); + readReleaser2.Dispose(); + Assert.Throws(() => readReleaser.Dispose()); + Assert.Throws(() => readReleaser2.Dispose()); + } + + [Test] + public void TestOperationAfterDisposeAsync() + { + var l = new AsyncReaderWriterLock(); + l.Dispose(); + + Assert.ThrowsAsync(() => l.ReadLockAsync()); + Assert.ThrowsAsync(() => l.WriteLockAsync()); + } + + [Test] + public async Task TestInvaildExitWriteLockUsageAsync() + { + var l = new AsyncReaderWriterLock(); + var writeReleaser = await (l.WriteLockAsync()); + + writeReleaser.Dispose(); + Assert.Throws(() => writeReleaser.Dispose()); + } + + private static async Task LockAsync( + AsyncReaderWriterLock readWriteLock, + Random random, + LockStatistics lockStatistics, + System.Action checkLockAction, + Func canContinue, CancellationToken cancellationToken = default(CancellationToken)) + { + while (canContinue()) + { + var isRead = random.Next(100) < 80; + var releaser = isRead ? await (readWriteLock.ReadLockAsync()) : await (readWriteLock.WriteLockAsync()); + lock (readWriteLock) + { + if (isRead) + { + lockStatistics.ReadLockCount++; + } + else + { + lockStatistics.WriteLockCount++; + } + + checkLockAction(); + } + + await (Task.Delay(10, cancellationToken)); + + lock (readWriteLock) + { + releaser.Dispose(); + if (isRead) + { + lockStatistics.ReadLockCount--; + } + else + { + lockStatistics.WriteLockCount--; + } + + checkLockAction(); + } + } + } + + private static async Task AssertEqualValueAsync(Func getValueFunc, T value, Task task = null, int waitDelay = 5000, CancellationToken cancellationToken = default(CancellationToken)) + { + var currentTime = 0; + var step = 5; + while (currentTime < waitDelay) + { + if (task != null) + { + task.Wait(step); + } + else + { + await (Task.Delay(step, cancellationToken)); + } + + currentTime += step; + if (getValueFunc().Equals(value)) + { + return; + } + + step *= 2; + } + + Assert.That(getValueFunc(), Is.EqualTo(value)); + } + + private static Task AssertTaskCompletedAsync(Task task, CancellationToken cancellationToken = default(CancellationToken)) + { + return AssertEqualValueAsync(() => task.IsCompleted, true, task, cancellationToken: cancellationToken); + } + } +} diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 365e1dcbf99..16ef3aab72e 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -44,6 +44,11 @@ + + + UtilityTest\AsyncReaderWriterLock.cs + + diff --git a/src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs b/src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs new file mode 100644 index 00000000000..b737b044def --- /dev/null +++ b/src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs @@ -0,0 +1,475 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.UtilityTest +{ + public partial class AsyncReaderWriterLockFixture + { + [Test] + public void TestBlocking() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var readReleaser = l.ReadLock(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + + var readReleaserTask = Task.Run(() => l.ReadLock()); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + var writeReleaserTask = Task.Run(() => l.WriteLock()); + AssertEqualValue(() => l.AcquiredWriteLock, true, writeReleaserTask); + AssertEqualValue(() => l.Writing, false, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaser.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertTaskCompleted(writeReleaserTask); + + readReleaserTask = Task.Run(() => l.ReadLock()); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + + var writeReleaserTask2 = Task.Run(() => l.WriteLock()); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask2); + Assert.That(writeReleaserTask2.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask2); + AssertEqualValue(() => l.Writing, true, writeReleaserTask2); + Assert.That(readReleaserTask.IsCompleted, Is.False); + AssertTaskCompleted(writeReleaserTask2); + + writeReleaserTask2.Result.Dispose(); + AssertEqualValue(() => l.Writing, false, writeReleaserTask2); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaserTask.Result.Dispose(); + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.WritersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + Assert.That(l.Writing, Is.False); + } + } + + [Test] + public void TestBlockingAsync() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + var readReleaserTask2 = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask2); + AssertTaskCompleted(readReleaserTask2); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.AcquiredWriteLock, true, writeReleaserTask); + AssertEqualValue(() => l.Writing, false, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask2.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertTaskCompleted(writeReleaserTask); + + readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + + var writeReleaserTask2 = l.WriteLockAsync(); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask2); + Assert.That(writeReleaserTask2.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask2); + AssertEqualValue(() => l.Writing, true, writeReleaserTask2); + Assert.That(readReleaserTask.IsCompleted, Is.False); + AssertTaskCompleted(writeReleaserTask2); + + writeReleaserTask2.Result.Dispose(); + AssertEqualValue(() => l.Writing, false, writeReleaserTask2); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaserTask.Result.Dispose(); + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.WritersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + Assert.That(l.Writing, Is.False); + } + } + + [Test, Explicit] + public void TestConcurrentReadWrite() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var writeReleaser = l.WriteLock(); + Assert.That(l.Writing, Is.True); + + var secondWriteSemaphore = new SemaphoreSlim(0); + var secondWriteReleaser = default(AsyncReaderWriterLock.Releaser); + var secondWriteThread = new Thread( + () => + { + secondWriteSemaphore.Wait(); + secondWriteReleaser = l.WriteLock(); + }); + secondWriteThread.Priority = ThreadPriority.Highest; + secondWriteThread.Start(); + AssertEqualValue(() => secondWriteThread.ThreadState == ThreadState.WaitSleepJoin, true); + + var secondReadThreads = new Thread[20]; + var secondReadReleasers = new AsyncReaderWriterLock.Releaser[secondReadThreads.Length]; + var secondReadSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < secondReadReleasers.Length; j++) + { + var index = j; + var thread = new Thread( + () => + { + secondReadSemaphore.Wait(); + secondReadReleasers[index] = l.ReadLock(); + }); + thread.Priority = ThreadPriority.Highest; + secondReadThreads[j] = thread; + thread.Start(); + } + + AssertEqualValue(() => secondReadThreads.All(o => o.ThreadState == ThreadState.WaitSleepJoin), true); + + var firstReadReleaserTasks = new Task[30]; + var firstReadStopSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < firstReadReleaserTasks.Length; j++) + { + firstReadReleaserTasks[j] = Task.Run(() => + { + var releaser = l.ReadLock(); + firstReadStopSemaphore.Wait(); + releaser.Dispose(); + }); + } + + AssertEqualValue(() => l.ReadersWaiting, firstReadReleaserTasks.Length, waitDelay: 60000); + + writeReleaser.Dispose(); + secondWriteSemaphore.Release(); + secondReadSemaphore.Release(secondReadThreads.Length); + Thread.Sleep(1000); + firstReadStopSemaphore.Release(firstReadReleaserTasks.Length); + + AssertEqualValue(() => firstReadReleaserTasks.All(o => o.IsCompleted), true); + Assert.That(l.ReadersWaiting, Is.EqualTo(secondReadThreads.Length)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + AssertEqualValue(() => secondWriteThread.IsAlive, false); + AssertEqualValue(() => secondReadThreads.All(o => o.IsAlive), true); + + secondWriteReleaser.Dispose(); + AssertEqualValue(() => secondReadThreads.All(o => !o.IsAlive), true); + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(secondReadThreads.Length)); + + foreach (var secondReadReleaser in secondReadReleasers) + { + secondReadReleaser.Dispose(); + } + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + } + } + + [Test] + public void TestInvaildExitReadLockUsage() + { + var l = new AsyncReaderWriterLock(); + var readReleaser = l.ReadLock(); + var readReleaser2 = l.ReadLock(); + + readReleaser.Dispose(); + readReleaser2.Dispose(); + Assert.Throws(() => readReleaser.Dispose()); + Assert.Throws(() => readReleaser2.Dispose()); + } + + [Test] + public void TestOperationAfterDispose() + { + var l = new AsyncReaderWriterLock(); + l.Dispose(); + + Assert.Throws(() => l.ReadLock()); + Assert.Throws(() => l.WriteLock()); + } + + [Test] + public void TestInvaildExitWriteLockUsage() + { + var l = new AsyncReaderWriterLock(); + var writeReleaser = l.WriteLock(); + + writeReleaser.Dispose(); + Assert.Throws(() => writeReleaser.Dispose()); + } + + [Test] + public void TestMixingSyncAndAsync() + { + var l = new AsyncReaderWriterLock(); + var readReleaser = l.ReadLock(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaser.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + + readReleaserTask.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + + var writeReleaser = l.WriteLock(); + Assert.That(l.AcquiredWriteLock, Is.True); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask = Task.Run(() => l.ReadLock()); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + + var readReleaserTask2 = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 2, readReleaserTask2); + Assert.That(readReleaserTask2.IsCompleted, Is.False); + + writeReleaser.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertTaskCompleted(writeReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + Assert.That(readReleaserTask2.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask2); + AssertTaskCompleted(readReleaserTask); + AssertTaskCompleted(readReleaserTask2); + } + + [Test] + public void TestWritePriorityOverReadAsync() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var writeReleaser = l.WriteLock(); + Assert.That(l.AcquiredWriteLock, Is.True); + + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask); + + writeReleaser.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + AssertTaskCompleted(writeReleaserTask); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaserTask.Result.Dispose(); + } + } + + [Test] + public void TestPartialReleasingReadLockAsync() + { + var l = new AsyncReaderWriterLock(); + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + var readReleaserTask2 = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertTaskCompleted(readReleaserTask2); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.AcquiredWriteLock, true, writeReleaserTask); + AssertEqualValue(() => l.Writing, false, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + var readReleaserTask3 = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask3); + Assert.That(readReleaserTask3.IsCompleted, Is.False); + + readReleaserTask.Result.Dispose(); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + Assert.That(readReleaserTask3.IsCompleted, Is.False); + + readReleaserTask2.Result.Dispose(); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask3); + AssertTaskCompleted(writeReleaserTask); + Assert.That(readReleaserTask3.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask3); + AssertTaskCompleted(readReleaserTask3); + } + + [Test, Explicit] + public async Task TestLoadSyncAndAsync() + { + var l = new AsyncReaderWriterLock(); + var lockStatistics = new LockStatistics(); + var incorrectLockCount = false; + var tasks = new Task[20]; + var masterRandom = new Random(); + var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + for (var i = 0; i < tasks.Length; i++) + { + // Ensure that each random has its own unique seed + var random = new Random(masterRandom.Next()); + tasks[i] = i % 2 == 0 + ? Task.Run(() => Lock(l, random, lockStatistics, CheckLockCount, CanContinue)) + : LockAsync(l, random, lockStatistics, CheckLockCount, CanContinue); + } + + await Task.WhenAll(tasks); + Assert.That(incorrectLockCount, Is.False); + + void CheckLockCount() + { + if (!lockStatistics.Validate()) + { + Volatile.Write(ref incorrectLockCount, true); + } + } + + bool CanContinue() + { + return !cancellationTokenSource.Token.IsCancellationRequested; + } + } + + private class LockStatistics + { + public int ReadLockCount { get; set; } + + public int WriteLockCount { get; set; } + + public bool Validate() + { + return (ReadLockCount == 0 && WriteLockCount == 0) || + (ReadLockCount > 0 && WriteLockCount == 0) || + (ReadLockCount == 0 && WriteLockCount == 1); + } + } + + private static void Lock( + AsyncReaderWriterLock readWriteLock, + Random random, + LockStatistics lockStatistics, + System.Action checkLockAction, + Func canContinue) + { + while (canContinue()) + { + var isRead = random.Next(100) < 80; + var releaser = isRead ? readWriteLock.ReadLock() : readWriteLock.WriteLock(); + lock (readWriteLock) + { + if (isRead) + { + lockStatistics.ReadLockCount++; + } + else + { + lockStatistics.WriteLockCount++; + } + + checkLockAction(); + } + + Thread.Sleep(10); + + lock (readWriteLock) + { + releaser.Dispose(); + if (isRead) + { + lockStatistics.ReadLockCount--; + } + else + { + lockStatistics.WriteLockCount--; + } + + checkLockAction(); + } + } + } + + private static void AssertEqualValue(Func getValueFunc, T value, Task task = null, int waitDelay = 5000) + { + var currentTime = 0; + var step = 5; + while (currentTime < waitDelay) + { + if (task != null) + { + task.Wait(step); + } + else + { + Thread.Sleep(step); + } + + currentTime += step; + if (getValueFunc().Equals(value)) + { + return; + } + + step *= 2; + } + + Assert.That(getValueFunc(), Is.EqualTo(value)); + } + + private static void AssertTaskCompleted(Task task) + { + AssertEqualValue(() => task.IsCompleted, true, task); + } + } +} diff --git a/src/NHibernate/Async/Cache/ReadWriteCache.cs b/src/NHibernate/Async/Cache/ReadWriteCache.cs index 326e344bbc4..eac3b2bc339 100644 --- a/src/NHibernate/Async/Cache/ReadWriteCache.cs +++ b/src/NHibernate/Async/Cache/ReadWriteCache.cs @@ -12,6 +12,7 @@ using System.Collections.Generic; using System.Linq; using NHibernate.Cache.Access; +using NHibernate.Util; namespace NHibernate.Cache { @@ -19,7 +20,6 @@ namespace NHibernate.Cache using System.Threading; public partial class ReadWriteCache : IBatchableCacheConcurrencyStrategy { - private readonly NHibernate.Util.AsyncLock _lockObjectAsync = new NHibernate.Util.AsyncLock(); /// /// Do not return an item whose timestamp is later than the current @@ -41,7 +41,7 @@ public partial class ReadWriteCache : IBatchableCacheConcurrencyStrategy public async Task GetAsync(CacheKey key, long txTimestamp, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -70,7 +70,8 @@ public async Task GetManyAsync(CacheKey[] keys, long timestamp, Cancel log.Debug("Cache lookup: {0}", string.Join(",", keys.AsEnumerable())); } var result = new object[keys.Length]; - using (await _lockObjectAsync.LockAsync()) + cancellationToken.ThrowIfCancellationRequested(); + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) { var lockables = await (_cache.GetManyAsync(keys, cancellationToken)).ConfigureAwait(false); for (var i = 0; i < lockables.Length; i++) @@ -92,7 +93,7 @@ public async Task GetManyAsync(CacheKey[] keys, long timestamp, Cancel public async Task LockAsync(CacheKey key, object version, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -135,8 +136,9 @@ public async Task PutManyAsync( // MinValue means cache is disabled return result; } + cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -205,8 +207,9 @@ public async Task PutAsync(CacheKey key, object value, long txTimestamp, o // MinValue means cache is disabled return false; } + cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -270,7 +273,7 @@ private Task DecrementLockAsync(object key, CacheLock @lock, CancellationToken c public async Task ReleaseAsync(CacheKey key, ISoftLock clientLock, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -343,7 +346,7 @@ public Task RemoveAsync(CacheKey key, CancellationToken cancellationToken) public async Task AfterUpdateAsync(CacheKey key, object value, object version, ISoftLock clientLock, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -390,7 +393,7 @@ public async Task AfterUpdateAsync(CacheKey key, object value, object vers public async Task AfterInsertAsync(CacheKey key, object value, object version, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { diff --git a/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs b/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs index f97f25401be..774c25bec75 100644 --- a/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs +++ b/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs @@ -22,10 +22,6 @@ namespace NHibernate.Cache using System.Threading; public partial class UpdateTimestampsCache { - private readonly NHibernate.Util.AsyncLock _preInvalidate = new NHibernate.Util.AsyncLock(); - private readonly NHibernate.Util.AsyncLock _invalidate = new NHibernate.Util.AsyncLock(); - private readonly NHibernate.Util.AsyncLock _isUpToDate = new NHibernate.Util.AsyncLock(); - private readonly NHibernate.Util.AsyncLock _areUpToDate = new NHibernate.Util.AsyncLock(); public virtual Task ClearAsync(CancellationToken cancellationToken) { @@ -55,20 +51,20 @@ public Task PreInvalidateAsync(object[] spaces, CancellationToken cancellationTo } } - [MethodImpl()] public virtual async Task PreInvalidateAsync(IReadOnlyCollection spaces, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _preInvalidate.LockAsync()) + if (spaces.Count == 0) + return; + cancellationToken.ThrowIfCancellationRequested(); + + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { //TODO: to handle concurrent writes correctly, this should return a Lock to the client var ts = _updateTimestamps.NextTimestamp() + _updateTimestamps.Timeout; await (SetSpacesTimestampAsync(spaces, ts, cancellationToken)).ConfigureAwait(false); - //TODO: return new Lock(ts); } - - //TODO: return new Lock(ts); } //Since v5.1 @@ -90,11 +86,14 @@ public Task InvalidateAsync(object[] spaces, CancellationToken cancellationToken } } - [MethodImpl()] public virtual async Task InvalidateAsync(IReadOnlyCollection spaces, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _invalidate.LockAsync()) + if (spaces.Count == 0) + return; + cancellationToken.ThrowIfCancellationRequested(); + + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { //TODO: to handle concurrent writes correctly, the client should pass in a Lock long ts = _updateTimestamps.NextTimestamp(); @@ -113,9 +112,6 @@ private Task SetSpacesTimestampAsync(IReadOnlyCollection spaces, long ts } try { - if (spaces.Count == 0) - return Task.CompletedTask; - return _updateTimestamps.PutManyAsync( spaces.ToArray(), ArrayHelper.Fill(ts, spaces.Count), cancellationToken); @@ -126,45 +122,45 @@ private Task SetSpacesTimestampAsync(IReadOnlyCollection spaces, long ts } } - [MethodImpl()] public virtual async Task IsUpToDateAsync(ISet spaces, long timestamp /* H2.1 has Long here */, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _isUpToDate.LockAsync()) - { - if (spaces.Count == 0) - return true; + if (spaces.Count == 0) + return true; + cancellationToken.ThrowIfCancellationRequested(); + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) + { var lastUpdates = await (_updateTimestamps.GetManyAsync(spaces.ToArray(), cancellationToken)).ConfigureAwait(false); return lastUpdates.All(lastUpdate => !IsOutdated(lastUpdate as long?, timestamp)); } } - [MethodImpl()] public virtual async Task AreUpToDateAsync(ISet[] spaces, long[] timestamps, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _areUpToDate.LockAsync()) - { - if (spaces.Length == 0) - return Array.Empty(); + if (spaces.Length == 0) + return Array.Empty(); - var allSpaces = new HashSet(); - foreach (var sp in spaces) - { - allSpaces.UnionWith(sp); - } + var allSpaces = new HashSet(); + foreach (var sp in spaces) + { + allSpaces.UnionWith(sp); + } - if (allSpaces.Count == 0) - return ArrayHelper.Fill(true, spaces.Length); + if (allSpaces.Count == 0) + return ArrayHelper.Fill(true, spaces.Length); - var keys = allSpaces.ToArray(); + var keys = allSpaces.ToArray(); + cancellationToken.ThrowIfCancellationRequested(); + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) + { var index = 0; var lastUpdatesBySpace = - (await (_updateTimestamps - .GetManyAsync(keys, cancellationToken)).ConfigureAwait(false)) - .ToDictionary(u => keys[index++], u => u as long?); + (await (_updateTimestamps + .GetManyAsync(keys, cancellationToken)).ConfigureAwait(false)) + .ToDictionary(u => keys[index++], u => u as long?); var results = new bool[spaces.Length]; for (var i = 0; i < spaces.Length; i++) diff --git a/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs b/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs index 63f63c9ff6c..5ba08f4a31e 100644 --- a/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs +++ b/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs @@ -24,13 +24,11 @@ public partial class OptimizerFactory public partial class HiLoOptimizer : OptimizerSupport { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public override async Task GenerateAsync(IAccessCallback callback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_lastSourceValue < 0) { @@ -51,6 +49,7 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance _lastSourceValue = await (callback.GetNextValueAsync(cancellationToken)).ConfigureAwait(false); _upperLimit = (_lastSourceValue * IncrementSize) + 1; } + return Make(_value++); } } @@ -101,13 +100,11 @@ public abstract partial class OptimizerSupport : IOptimizer public partial class PooledOptimizer : OptimizerSupport, IInitialValueAwareOptimizer { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public override async Task GenerateAsync(IAccessCallback callback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_hiValue < 0) { @@ -134,6 +131,7 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance _hiValue = await (callback.GetNextValueAsync(cancellationToken)).ConfigureAwait(false); _value = _hiValue - IncrementSize; } + return Make(_value++); } } @@ -145,13 +143,11 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance public partial class PooledLoOptimizer : OptimizerSupport { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public override async Task GenerateAsync(IAccessCallback callback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_lastSourceValue < 0 || _value >= (_lastSourceValue + IncrementSize)) { @@ -161,6 +157,7 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance while (_value < 1) _value++; } + return Make(_value++); } } diff --git a/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs b/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs index 3dd339de624..0ca2de00611 100644 --- a/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs +++ b/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs @@ -26,13 +26,11 @@ namespace NHibernate.Id.Enhanced using System.Threading; public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGenerator, IConfigurable { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public virtual async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { return await (Optimizer.GenerateAsync(new TableAccessCallback(session, this), cancellationToken)).ConfigureAwait(false); } diff --git a/src/NHibernate/Async/Id/IncrementGenerator.cs b/src/NHibernate/Async/Id/IncrementGenerator.cs index 0fd915d15cf..4df097d6624 100644 --- a/src/NHibernate/Async/Id/IncrementGenerator.cs +++ b/src/NHibernate/Async/Id/IncrementGenerator.cs @@ -19,6 +19,7 @@ using NHibernate.SqlCommand; using NHibernate.SqlTypes; using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Id { @@ -26,7 +27,6 @@ namespace NHibernate.Id using System.Threading; public partial class IncrementGenerator : IIdentifierGenerator, IConfigurable { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); /// /// @@ -35,16 +35,16 @@ public partial class IncrementGenerator : IIdentifierGenerator, IConfigurable /// /// A cancellation token that can be used to cancel the work /// - [MethodImpl()] public async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_sql != null) { await (GetNextAsync(session, cancellationToken)).ConfigureAwait(false); } + return IdentifierGeneratorFactory.CreateNumber(_next++, _returnClass); } } diff --git a/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs b/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs index 94ee6d72da5..75992f456ed 100644 --- a/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs +++ b/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs @@ -23,7 +23,6 @@ namespace NHibernate.Id using System.Threading; public partial class SequenceHiLoGenerator : SequenceGenerator { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); #region IIdentifierGenerator Members @@ -35,11 +34,10 @@ public partial class SequenceHiLoGenerator : SequenceGenerator /// The entity for which the id is being generated. /// A cancellation token that can be used to cancel the work /// The new identifier as a , , or . - [MethodImpl()] public override async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (maxLo < 1) { diff --git a/src/NHibernate/Async/Id/TableGenerator.cs b/src/NHibernate/Async/Id/TableGenerator.cs index b2731653a92..3ad468a09be 100644 --- a/src/NHibernate/Async/Id/TableGenerator.cs +++ b/src/NHibernate/Async/Id/TableGenerator.cs @@ -29,7 +29,6 @@ namespace NHibernate.Id using System.Threading; public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGenerator, IConfigurable { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); #region IIdentifierGenerator Members @@ -41,11 +40,10 @@ public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGe /// The entity for which the id is being generated. /// A cancellation token that can be used to cancel the work /// The new identifier as a , , or . - [MethodImpl()] public virtual async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { // This has to be done using a different connection to the containing // transaction becase the new hi value must remain valid even if the diff --git a/src/NHibernate/Async/Id/TableHiLoGenerator.cs b/src/NHibernate/Async/Id/TableHiLoGenerator.cs index 8302dad5e98..663733f6b39 100644 --- a/src/NHibernate/Async/Id/TableHiLoGenerator.cs +++ b/src/NHibernate/Async/Id/TableHiLoGenerator.cs @@ -23,7 +23,6 @@ namespace NHibernate.Id using System.Threading; public partial class TableHiLoGenerator : TableGenerator { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); #region IIdentifierGenerator Members @@ -34,11 +33,10 @@ public partial class TableHiLoGenerator : TableGenerator /// The entity for which the id is being generated. /// A cancellation token that can be used to cancel the work /// The new identifier as a . - [MethodImpl()] public override async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (maxLo < 1) { diff --git a/src/NHibernate/Cache/ReadWriteCache.cs b/src/NHibernate/Cache/ReadWriteCache.cs index 2bf891f3068..9bb25e51048 100644 --- a/src/NHibernate/Cache/ReadWriteCache.cs +++ b/src/NHibernate/Cache/ReadWriteCache.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using NHibernate.Cache.Access; +using NHibernate.Util; namespace NHibernate.Cache { @@ -33,9 +34,9 @@ public interface ILockable private static readonly INHibernateLogger log = NHibernateLogger.For(typeof(ReadWriteCache)); - private readonly object _lockObject = new object(); private CacheBase _cache; private int _nextLockId; + private readonly AsyncReaderWriterLock _asyncReaderWriterLock = new AsyncReaderWriterLock(); /// /// Gets the cache region name. @@ -95,7 +96,7 @@ private int NextLockId() /// public object Get(CacheKey key, long txTimestamp) { - lock (_lockObject) + using (_asyncReaderWriterLock.ReadLock()) { if (log.IsDebugEnabled()) { @@ -123,7 +124,7 @@ public object[] GetMany(CacheKey[] keys, long timestamp) log.Debug("Cache lookup: {0}", string.Join(",", keys.AsEnumerable())); } var result = new object[keys.Length]; - lock (_lockObject) + using (_asyncReaderWriterLock.ReadLock()) { var lockables = _cache.GetMany(keys); for (var i = 0; i < lockables.Length; i++) @@ -166,7 +167,7 @@ private static object GetValue(long timestamp, CacheKey key, ILockable lockable) /// public ISoftLock Lock(CacheKey key, object version) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -209,7 +210,7 @@ public bool[] PutMany( return result; } - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -278,7 +279,7 @@ public bool Put(CacheKey key, object value, long txTimestamp, object version, IC return false; } - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -330,7 +331,7 @@ private void DecrementLock(object key, CacheLock @lock) public void Release(CacheKey key, ISoftLock clientLock) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -382,6 +383,7 @@ public void Destroy() // The cache is externally provided and may be shared. Destroying the cache is // not the responsibility of this class. Cache = null; + _asyncReaderWriterLock.Dispose(); } /// @@ -390,7 +392,7 @@ public void Destroy() /// public bool AfterUpdate(CacheKey key, object value, object version, ISoftLock clientLock) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -436,7 +438,7 @@ public bool AfterUpdate(CacheKey key, object value, object version, ISoftLock cl public bool AfterInsert(CacheKey key, object value, object version) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { diff --git a/src/NHibernate/Cache/UpdateTimestampsCache.cs b/src/NHibernate/Cache/UpdateTimestampsCache.cs index 40369e4ac97..f6851f5ed44 100644 --- a/src/NHibernate/Cache/UpdateTimestampsCache.cs +++ b/src/NHibernate/Cache/UpdateTimestampsCache.cs @@ -19,6 +19,7 @@ public partial class UpdateTimestampsCache { private static readonly INHibernateLogger log = NHibernateLogger.For(typeof(UpdateTimestampsCache)); private readonly CacheBase _updateTimestamps; + private readonly AsyncReaderWriterLock _asyncReaderWriterLock = new AsyncReaderWriterLock(); public virtual void Clear() { @@ -54,14 +55,18 @@ public void PreInvalidate(object[] spaces) PreInvalidate(spaces.OfType().ToList()); } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual void PreInvalidate(IReadOnlyCollection spaces) { - //TODO: to handle concurrent writes correctly, this should return a Lock to the client - var ts = _updateTimestamps.NextTimestamp() + _updateTimestamps.Timeout; - SetSpacesTimestamp(spaces, ts); + if (spaces.Count == 0) + return; - //TODO: return new Lock(ts); + using (_asyncReaderWriterLock.WriteLock()) + { + //TODO: to handle concurrent writes correctly, this should return a Lock to the client + var ts = _updateTimestamps.NextTimestamp() + _updateTimestamps.Timeout; + SetSpacesTimestamp(spaces, ts); + //TODO: return new Lock(ts); + } } //Since v5.1 @@ -72,38 +77,41 @@ public void Invalidate(object[] spaces) Invalidate(spaces.OfType().ToList()); } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual void Invalidate(IReadOnlyCollection spaces) { - //TODO: to handle concurrent writes correctly, the client should pass in a Lock - long ts = _updateTimestamps.NextTimestamp(); - //TODO: if lock.getTimestamp().equals(ts) - if (log.IsDebugEnabled()) - log.Debug("Invalidating spaces [{0}]", StringHelper.CollectionToString(spaces)); - SetSpacesTimestamp(spaces, ts); + if (spaces.Count == 0) + return; + + using (_asyncReaderWriterLock.WriteLock()) + { + //TODO: to handle concurrent writes correctly, the client should pass in a Lock + long ts = _updateTimestamps.NextTimestamp(); + //TODO: if lock.getTimestamp().equals(ts) + if (log.IsDebugEnabled()) + log.Debug("Invalidating spaces [{0}]", StringHelper.CollectionToString(spaces)); + SetSpacesTimestamp(spaces, ts); + } } private void SetSpacesTimestamp(IReadOnlyCollection spaces, long ts) { - if (spaces.Count == 0) - return; - _updateTimestamps.PutMany( spaces.ToArray(), ArrayHelper.Fill(ts, spaces.Count)); } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual bool IsUpToDate(ISet spaces, long timestamp /* H2.1 has Long here */) { if (spaces.Count == 0) return true; - var lastUpdates = _updateTimestamps.GetMany(spaces.ToArray()); - return lastUpdates.All(lastUpdate => !IsOutdated(lastUpdate as long?, timestamp)); + using (_asyncReaderWriterLock.ReadLock()) + { + var lastUpdates = _updateTimestamps.GetMany(spaces.ToArray()); + return lastUpdates.All(lastUpdate => !IsOutdated(lastUpdate as long?, timestamp)); + } } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual bool[] AreUpToDate(ISet[] spaces, long[] timestamps) { if (spaces.Length == 0) @@ -120,20 +128,23 @@ public virtual bool[] AreUpToDate(ISet[] spaces, long[] timestamps) var keys = allSpaces.ToArray(); - var index = 0; - var lastUpdatesBySpace = - _updateTimestamps - .GetMany(keys) - .ToDictionary(u => keys[index++], u => u as long?); - - var results = new bool[spaces.Length]; - for (var i = 0; i < spaces.Length; i++) + using (_asyncReaderWriterLock.ReadLock()) { - var timestamp = timestamps[i]; - results[i] = spaces[i].All(space => !IsOutdated(lastUpdatesBySpace[space], timestamp)); - } + var index = 0; + var lastUpdatesBySpace = + _updateTimestamps + .GetMany(keys) + .ToDictionary(u => keys[index++], u => u as long?); + + var results = new bool[spaces.Length]; + for (var i = 0; i < spaces.Length; i++) + { + var timestamp = timestamps[i]; + results[i] = spaces[i].All(space => !IsOutdated(lastUpdatesBySpace[space], timestamp)); + } - return results; + return results; + } } // Since v5.3 @@ -142,6 +153,7 @@ public virtual void Destroy() { // The cache is externally provided and may be shared. Destroying the cache is // not the responsibility of this class. + _asyncReaderWriterLock.Dispose(); } private static bool IsOutdated(long? lastUpdate, long timestamp) diff --git a/src/NHibernate/Id/Enhanced/OptimizerFactory.cs b/src/NHibernate/Id/Enhanced/OptimizerFactory.cs index 0a410f7fda2..0adf3695551 100644 --- a/src/NHibernate/Id/Enhanced/OptimizerFactory.cs +++ b/src/NHibernate/Id/Enhanced/OptimizerFactory.cs @@ -101,6 +101,7 @@ public partial class HiLoOptimizer : OptimizerSupport private long _upperLimit; private long _lastSourceValue = -1; private long _value; + private readonly AsyncLock _asyncLock = new AsyncLock(); public HiLoOptimizer(System.Type returnClass, int incrementSize) : base(returnClass, incrementSize) { @@ -140,29 +141,32 @@ public override bool ApplyIncrementSizeToSourceValues get { return false; } } - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(IAccessCallback callback) { - if (_lastSourceValue < 0) + using (_asyncLock.Lock()) { - _lastSourceValue = callback.GetNextValue(); - while (_lastSourceValue <= 0) + if (_lastSourceValue < 0) { _lastSourceValue = callback.GetNextValue(); - } + while (_lastSourceValue <= 0) + { + _lastSourceValue = callback.GetNextValue(); + } - // upperLimit defines the upper end of the bucket values - _upperLimit = (_lastSourceValue * IncrementSize) + 1; + // upperLimit defines the upper end of the bucket values + _upperLimit = (_lastSourceValue * IncrementSize) + 1; - // initialize value to the low end of the bucket - _value = _upperLimit - IncrementSize; - } - else if (_upperLimit <= _value) - { - _lastSourceValue = callback.GetNextValue(); - _upperLimit = (_lastSourceValue * IncrementSize) + 1; + // initialize value to the low end of the bucket + _value = _upperLimit - IncrementSize; + } + else if (_upperLimit <= _value) + { + _lastSourceValue = callback.GetNextValue(); + _upperLimit = (_lastSourceValue * IncrementSize) + 1; + } + + return Make(_value++); } - return Make(_value++); } } @@ -267,6 +271,7 @@ public partial class PooledOptimizer : OptimizerSupport, IInitialValueAwareOptim private long _hiValue = -1; private long _value; private long _initialValue; + private readonly AsyncLock _asyncLock = new AsyncLock(); public PooledOptimizer(System.Type returnClass, int incrementSize) : base(returnClass, incrementSize) { @@ -303,35 +308,38 @@ public void InjectInitialValue(long initialValue) _initialValue = initialValue; } - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(IAccessCallback callback) { - if (_hiValue < 0) + using (_asyncLock.Lock()) { - _value = callback.GetNextValue(); - if (_value < 1) + if (_hiValue < 0) { - // unfortunately not really safe to normalize this - // to 1 as an initial value like we do the others - // because we would not be able to control this if - // we are using a sequence... - Log.Info("pooled optimizer source reported [{0}] as the initial value; use of 1 or greater highly recommended", _value); + _value = callback.GetNextValue(); + if (_value < 1) + { + // unfortunately not really safe to normalize this + // to 1 as an initial value like we do the others + // because we would not be able to control this if + // we are using a sequence... + Log.Info("pooled optimizer source reported [{0}] as the initial value; use of 1 or greater highly recommended", _value); + } + + if ((_initialValue == -1 && _value < IncrementSize) || _value == _initialValue) + _hiValue = callback.GetNextValue(); + else + { + _hiValue = _value; + _value = _hiValue - IncrementSize; + } } - - if ((_initialValue == -1 && _value < IncrementSize) || _value == _initialValue) - _hiValue = callback.GetNextValue(); - else + else if (_value >= _hiValue) { - _hiValue = _value; + _hiValue = callback.GetNextValue(); _value = _hiValue - IncrementSize; } + + return Make(_value++); } - else if (_value >= _hiValue) - { - _hiValue = callback.GetNextValue(); - _value = _hiValue - IncrementSize; - } - return Make(_value++); } } @@ -343,6 +351,7 @@ public partial class PooledLoOptimizer : OptimizerSupport { private long _lastSourceValue = -1; // last value read from db source private long _value; // the current generator value + private readonly AsyncLock _asyncLock = new AsyncLock(); public PooledLoOptimizer(System.Type returnClass, int incrementSize) : base(returnClass, incrementSize) { @@ -356,18 +365,21 @@ public PooledLoOptimizer(System.Type returnClass, int incrementSize) : base(retu } } - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(IAccessCallback callback) { - if (_lastSourceValue < 0 || _value >= (_lastSourceValue + IncrementSize)) + using (_asyncLock.Lock()) { - _lastSourceValue = callback.GetNextValue(); - _value = _lastSourceValue; - // handle cases where initial-value is less than one (hsqldb for instance). - while (_value < 1) - _value++; + if (_lastSourceValue < 0 || _value >= (_lastSourceValue + IncrementSize)) + { + _lastSourceValue = callback.GetNextValue(); + _value = _lastSourceValue; + // handle cases where initial-value is less than one (hsqldb for instance). + while (_value < 1) + _value++; + } + + return Make(_value++); } - return Make(_value++); } public override long LastSourceValue diff --git a/src/NHibernate/Id/Enhanced/TableGenerator.cs b/src/NHibernate/Id/Enhanced/TableGenerator.cs index 881280f5237..60a287db13f 100644 --- a/src/NHibernate/Id/Enhanced/TableGenerator.cs +++ b/src/NHibernate/Id/Enhanced/TableGenerator.cs @@ -181,6 +181,7 @@ public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGe private SqlTypes.SqlType[] insertParameterTypes; private SqlString updateQuery; private SqlTypes.SqlType[] updateParameterTypes; + private readonly AsyncLock _asyncLock = new AsyncLock(); public virtual string GeneratorKey() { @@ -378,10 +379,12 @@ protected void BuildInsertQuery() }; } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual object Generate(ISessionImplementor session, object obj) { - return Optimizer.Generate(new TableAccessCallback(session, this)); + using (_asyncLock.Lock()) + { + return Optimizer.Generate(new TableAccessCallback(session, this)); + } } private partial class TableAccessCallback : IAccessCallback diff --git a/src/NHibernate/Id/IncrementGenerator.cs b/src/NHibernate/Id/IncrementGenerator.cs index 1d4fef66ba7..ba4012f60ec 100644 --- a/src/NHibernate/Id/IncrementGenerator.cs +++ b/src/NHibernate/Id/IncrementGenerator.cs @@ -9,6 +9,7 @@ using NHibernate.SqlCommand; using NHibernate.SqlTypes; using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Id { @@ -32,6 +33,7 @@ public partial class IncrementGenerator : IIdentifierGenerator, IConfigurable private long _next; private SqlString _sql; private System.Type _returnClass; + private readonly AsyncLock _asyncLock = new AsyncLock(); /// /// @@ -85,14 +87,17 @@ public void Configure(IType type, IDictionary parms, Dialect.Dia /// /// /// - [MethodImpl(MethodImplOptions.Synchronized)] public object Generate(ISessionImplementor session, object obj) { - if (_sql != null) + using (_asyncLock.Lock()) { - GetNext(session); + if (_sql != null) + { + GetNext(session); + } + + return IdentifierGeneratorFactory.CreateNumber(_next++, _returnClass); } - return IdentifierGeneratorFactory.CreateNumber(_next++, _returnClass); } private void GetNext(ISessionImplementor session) diff --git a/src/NHibernate/Id/SequenceHiLoGenerator.cs b/src/NHibernate/Id/SequenceHiLoGenerator.cs index 8e07c079548..2d9fca0b85b 100644 --- a/src/NHibernate/Id/SequenceHiLoGenerator.cs +++ b/src/NHibernate/Id/SequenceHiLoGenerator.cs @@ -46,6 +46,7 @@ public partial class SequenceHiLoGenerator : SequenceGenerator private int lo; private long hi; private System.Type returnClass; + private readonly AsyncLock _asyncLock = new AsyncLock(); #region IConfigurable Members @@ -75,27 +76,29 @@ public override void Configure(IType type, IDictionary parms, Di /// The this id is being generated in. /// The entity for which the id is being generated. /// The new identifier as a , , or . - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(ISessionImplementor session, object obj) { - if (maxLo < 1) + using (_asyncLock.Lock()) { - //keep the behavior consistent even for boundary usages - long val = Convert.ToInt64(base.Generate(session, obj)); - if (val == 0) - val = Convert.ToInt64(base.Generate(session, obj)); - return IdentifierGeneratorFactory.CreateNumber(val, returnClass); - } + if (maxLo < 1) + { + //keep the behavior consistent even for boundary usages + long val = Convert.ToInt64(base.Generate(session, obj)); + if (val == 0) + val = Convert.ToInt64(base.Generate(session, obj)); + return IdentifierGeneratorFactory.CreateNumber(val, returnClass); + } - if (lo > maxLo) - { - long hival = Convert.ToInt64(base.Generate(session, obj)); - lo = (hival == 0) ? 1 : 0; - hi = hival * (maxLo + 1); - if (log.IsDebugEnabled()) - log.Debug("new hi value: {0}", hival); + if (lo > maxLo) + { + long hival = Convert.ToInt64(base.Generate(session, obj)); + lo = (hival == 0) ? 1 : 0; + hi = hival * (maxLo + 1); + if (log.IsDebugEnabled()) + log.Debug("new hi value: {0}", hival); + } + return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); } - return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); } #endregion diff --git a/src/NHibernate/Id/TableGenerator.cs b/src/NHibernate/Id/TableGenerator.cs index ce79bf2a532..09180687efc 100644 --- a/src/NHibernate/Id/TableGenerator.cs +++ b/src/NHibernate/Id/TableGenerator.cs @@ -70,6 +70,7 @@ public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGe private SqlString updateSql; private SqlType[] parameterTypes; + private readonly AsyncLock _asyncLock = new AsyncLock(); #region IConfigurable Members @@ -151,13 +152,15 @@ public virtual void Configure(IType type, IDictionary parms, Dia /// The this id is being generated in. /// The entity for which the id is being generated. /// The new identifier as a , , or . - [MethodImpl(MethodImplOptions.Synchronized)] public virtual object Generate(ISessionImplementor session, object obj) { - // This has to be done using a different connection to the containing - // transaction becase the new hi value must remain valid even if the - // containing transaction rolls back. - return DoWorkInNewTransaction(session); + using (_asyncLock.Lock()) + { + // This has to be done using a different connection to the containing + // transaction becase the new hi value must remain valid even if the + // containing transaction rolls back. + return DoWorkInNewTransaction(session); + } } #endregion diff --git a/src/NHibernate/Id/TableHiLoGenerator.cs b/src/NHibernate/Id/TableHiLoGenerator.cs index 402e5dbd467..c64de1d3323 100644 --- a/src/NHibernate/Id/TableHiLoGenerator.cs +++ b/src/NHibernate/Id/TableHiLoGenerator.cs @@ -52,6 +52,7 @@ public partial class TableHiLoGenerator : TableGenerator private long lo; private long maxLo; private System.Type returnClass; + private readonly AsyncLock _asyncLock = new AsyncLock(); #region IConfigurable Members @@ -80,26 +81,28 @@ public override void Configure(IType type, IDictionary parms, Di /// The this id is being generated in. /// The entity for which the id is being generated. /// The new identifier as a . - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(ISessionImplementor session, object obj) { - if (maxLo < 1) + using (_asyncLock.Lock()) { - //keep the behavior consistent even for boundary usages - long val = Convert.ToInt64(base.Generate(session, obj)); - if (val == 0) - val = Convert.ToInt64(base.Generate(session, obj)); - return IdentifierGeneratorFactory.CreateNumber(val, returnClass); - } - if (lo > maxLo) - { - long hival = Convert.ToInt64(base.Generate(session, obj)); - lo = (hival == 0) ? 1 : 0; - hi = hival * (maxLo + 1); - log.Debug("New high value: {0}", hival); - } + if (maxLo < 1) + { + //keep the behavior consistent even for boundary usages + long val = Convert.ToInt64(base.Generate(session, obj)); + if (val == 0) + val = Convert.ToInt64(base.Generate(session, obj)); + return IdentifierGeneratorFactory.CreateNumber(val, returnClass); + } + if (lo > maxLo) + { + long hival = Convert.ToInt64(base.Generate(session, obj)); + lo = (hival == 0) ? 1 : 0; + hi = hival * (maxLo + 1); + log.Debug("New high value: {0}", hival); + } - return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); + return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); + } } #endregion diff --git a/src/NHibernate/Util/AsyncLock.cs b/src/NHibernate/Util/AsyncLock.cs index f322d48f175..8a6f00bc95f 100644 --- a/src/NHibernate/Util/AsyncLock.cs +++ b/src/NHibernate/Util/AsyncLock.cs @@ -8,24 +8,32 @@ namespace NHibernate.Util public sealed class AsyncLock { private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); - private readonly Task _releaser; + private readonly IDisposable _releaser; + private readonly Task _releaserTask; public AsyncLock() { - _releaser = Task.FromResult((IDisposable)new Releaser(this)); + _releaser = new Releaser(this); + _releaserTask = Task.FromResult(_releaser); } public Task LockAsync() { var wait = _semaphore.WaitAsync(); return wait.IsCompleted ? - _releaser : + _releaserTask : wait.ContinueWith( (_, state) => (IDisposable)state, - _releaser.Result, CancellationToken.None, + _releaser, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); } + public IDisposable Lock() + { + _semaphore.Wait(); + return _releaser; + } + private sealed class Releaser : IDisposable { private readonly AsyncLock _toRelease; @@ -33,4 +41,4 @@ private sealed class Releaser : IDisposable public void Dispose() { _toRelease._semaphore.Release(); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/AsyncReaderWriterLock.cs b/src/NHibernate/Util/AsyncReaderWriterLock.cs new file mode 100644 index 00000000000..46d25b99c9f --- /dev/null +++ b/src/NHibernate/Util/AsyncReaderWriterLock.cs @@ -0,0 +1,252 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace NHibernate.Util +{ + // Idea from: + // https://github.com/kpreisser/AsyncReaderWriterLockSlim + // https://devblogs.microsoft.com/pfxteam/building-async-coordination-primitives-part-7-asyncreaderwriterlock/ + internal class AsyncReaderWriterLock : IDisposable + { + private readonly SemaphoreSlim _writeLockSemaphore = new SemaphoreSlim(1, 1); + private readonly SemaphoreSlim _readLockSemaphore = new SemaphoreSlim(0, 1); + private readonly Releaser _writerReleaser; + private readonly Releaser _readerReleaser; + private readonly Task _readerReleaserTask; + private SemaphoreSlim _waitingReadLockSemaphore; + private SemaphoreSlim _waitingDisposalSemaphore; + private int _readersWaiting; + private int _currentReaders; + private int _writersWaiting; + private bool _disposed; + + public AsyncReaderWriterLock() + { + _writerReleaser = new Releaser(this, true); + _readerReleaser = new Releaser(this, false); + _readerReleaserTask = Task.FromResult(_readerReleaser); + } + + internal int CurrentReaders => _currentReaders; + + internal int WritersWaiting => _writersWaiting; + + internal int ReadersWaiting => _readersWaiting; + + internal bool Writing => _currentReaders == 0 && _writeLockSemaphore.CurrentCount == 0; + + internal bool AcquiredWriteLock => _writeLockSemaphore.CurrentCount == 0; + + public Releaser WriteLock() + { + if (!CanEnterWriteLock(out var waitForReadLocks)) + { + _writeLockSemaphore.Wait(); + lock (_writeLockSemaphore) + { + _writersWaiting--; + } + } + + if (waitForReadLocks) + { + _readLockSemaphore.Wait(); + } + + DisposeWaitingSemaphore(); + + return _writerReleaser; + } + + public async Task WriteLockAsync() + { + if (!CanEnterWriteLock(out var waitForReadLocks)) + { + await _writeLockSemaphore.WaitAsync().ConfigureAwait(false); + lock (_writeLockSemaphore) + { + _writersWaiting--; + } + } + + if (waitForReadLocks) + { + await _readLockSemaphore.WaitAsync().ConfigureAwait(false); + } + + DisposeWaitingSemaphore(); + + return _writerReleaser; + } + + public Releaser ReadLock() + { + if (CanEnterReadLock()) + { + return _readerReleaser; + } + + _waitingReadLockSemaphore.Wait(); + + return _readerReleaser; + } + + public Task ReadLockAsync() + { + return CanEnterReadLock() ? _readerReleaserTask : ReadLockInternalAsync(); + + async Task ReadLockInternalAsync() + { + await _waitingReadLockSemaphore.WaitAsync().ConfigureAwait(false); + + return _readerReleaser; + } + } + + public void Dispose() + { + lock (_writeLockSemaphore) + { + _writeLockSemaphore.Dispose(); + _readLockSemaphore.Dispose(); + _waitingReadLockSemaphore?.Dispose(); + _waitingDisposalSemaphore?.Dispose(); + _disposed = true; + } + } + + private bool CanEnterWriteLock(out bool waitForReadLocks) + { + waitForReadLocks = false; + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_writeLockSemaphore.CurrentCount > 0 && _writeLockSemaphore.Wait(0)) + { + waitForReadLocks = _currentReaders > 0; + return true; + } + + _writersWaiting++; + } + + return false; + } + + private void ExitWriteLock() + { + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_writeLockSemaphore.CurrentCount == 1) + { + throw new InvalidOperationException(); + } + + // Writers have the highest priority even if they came last + if (_writersWaiting > 0 || _waitingReadLockSemaphore == null) + { + _writeLockSemaphore.Release(); + return; + } + + if (_readersWaiting > 0) + { + _currentReaders += _readersWaiting; + _waitingReadLockSemaphore.Release(_readersWaiting); + _readersWaiting = 0; + // We have to dispose the waiting read lock only after all readers finished using it + _waitingDisposalSemaphore = _waitingReadLockSemaphore; + _waitingReadLockSemaphore = null; + } + + _writeLockSemaphore.Release(); + } + } + + private bool CanEnterReadLock() + { + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_writersWaiting == 0 && _writeLockSemaphore.CurrentCount > 0) + { + _currentReaders++; + + return true; + } + + if (_waitingReadLockSemaphore == null) + { + _waitingReadLockSemaphore = new SemaphoreSlim(0); + } + + _readersWaiting++; + + return false; + } + } + + private void ExitReadLock() + { + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_currentReaders == 0) + { + throw new InvalidOperationException(); + } + + _currentReaders--; + if (_currentReaders == 0 && _writeLockSemaphore.CurrentCount == 0) + { + _readLockSemaphore.Release(); + } + } + } + + private void DisposeWaitingSemaphore() + { + _waitingDisposalSemaphore?.Dispose(); + _waitingDisposalSemaphore = null; + } + + private void AssertNotDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(AsyncReaderWriterLock)); + } + } + + public struct Releaser : IDisposable + { + private readonly AsyncReaderWriterLock _toRelease; + private readonly bool _writer; + + internal Releaser(AsyncReaderWriterLock toRelease, bool writer) + { + _toRelease = toRelease; + _writer = writer; + } + + public void Dispose() + { + if (_toRelease == null) + { + return; + } + + if (_writer) + { + _toRelease.ExitWriteLock(); + } + else + { + _toRelease.ExitReadLock(); + } + } + } + } +} From 3bb2aa0c80046fe6d1bc8169205b0d20dd88f0fa Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Mon, 13 Apr 2020 15:11:32 +0300 Subject: [PATCH 30/43] Configure log4net from embedded resource log4net.xml (#2235) --- src/NHibernate.Test/App.config | 46 +------------------ src/NHibernate.Test/Hql/Ast/ParsingFixture.cs | 4 +- src/NHibernate.Test/Log4netConfiguration.cs | 2 - .../NHSpecificTest/NH1587/Fixture.cs | 3 -- src/NHibernate.Test/NHibernate.Test.csproj | 1 + src/NHibernate.Test/TestCase.cs | 7 --- src/NHibernate.Test/TestsContext.cs | 38 +++++---------- .../TypesTest/TypeFactoryFixture.cs | 6 --- src/NHibernate.Test/log4net.xml | 45 ++++++++++++++++++ 9 files changed, 59 insertions(+), 93 deletions(-) delete mode 100644 src/NHibernate.Test/Log4netConfiguration.cs create mode 100644 src/NHibernate.Test/log4net.xml diff --git a/src/NHibernate.Test/App.config b/src/NHibernate.Test/App.config index d3965012af5..40cf748a495 100644 --- a/src/NHibernate.Test/App.config +++ b/src/NHibernate.Test/App.config @@ -3,7 +3,6 @@
-
@@ -52,50 +51,7 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + diff --git a/src/NHibernate.Test/Hql/Ast/ParsingFixture.cs b/src/NHibernate.Test/Hql/Ast/ParsingFixture.cs index 040ba037b9a..45be4789f1c 100644 --- a/src/NHibernate.Test/Hql/Ast/ParsingFixture.cs +++ b/src/NHibernate.Test/Hql/Ast/ParsingFixture.cs @@ -74,8 +74,6 @@ namespace NHibernate.Test.Hql.Ast // [Test] // public void BasicQuery() // { - // XmlConfigurator.Configure(); - // string input = "select o.id, li.id from NHibernate.Test.CompositeId.Order o join o.LineItems li";// join o.LineItems li"; // ISessionFactoryImplementor sfi = SetupSFI(); @@ -175,4 +173,4 @@ namespace NHibernate.Test.Hql.Ast // } // } //} -} \ No newline at end of file +} diff --git a/src/NHibernate.Test/Log4netConfiguration.cs b/src/NHibernate.Test/Log4netConfiguration.cs deleted file mode 100644 index 841ecc09911..00000000000 --- a/src/NHibernate.Test/Log4netConfiguration.cs +++ /dev/null @@ -1,2 +0,0 @@ -using log4net.Config; -[assembly: XmlConfigurator()] \ No newline at end of file diff --git a/src/NHibernate.Test/NHSpecificTest/NH1587/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/NH1587/Fixture.cs index 6a716ce5a66..68eb38d6bab 100644 --- a/src/NHibernate.Test/NHSpecificTest/NH1587/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/NH1587/Fixture.cs @@ -1,5 +1,3 @@ -using log4net; -using log4net.Config; using log4net.Core; using NHibernate.Cfg; using NUnit.Framework; @@ -12,7 +10,6 @@ public class Fixture [Test] public void Bug() { - XmlConfigurator.Configure(LogManager.GetRepository(typeof(Fixture).Assembly)); var cfg = new Configuration(); if (TestConfigurationHelper.hibernateConfigFile != null) cfg.Configure(TestConfigurationHelper.hibernateConfigFile); diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 16ef3aab72e..11e964f39a6 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -25,6 +25,7 @@ + Always diff --git a/src/NHibernate.Test/TestCase.cs b/src/NHibernate.Test/TestCase.cs index 600265ade3f..14fc4ad1460 100644 --- a/src/NHibernate.Test/TestCase.cs +++ b/src/NHibernate.Test/TestCase.cs @@ -4,7 +4,6 @@ using System.Data; using System.Reflection; using log4net; -using log4net.Config; using NHibernate.Cfg; using NHibernate.Connection; using NHibernate.Engine; @@ -56,12 +55,6 @@ protected virtual string MappingsAssembly protected SchemaExport SchemaExport => _schemaExport ?? (_schemaExport = new SchemaExport(cfg)); - static TestCase() - { - // Configure log4net here since configuration through an attribute doesn't always work. - XmlConfigurator.Configure(LogManager.GetRepository(typeof(TestCase).Assembly)); - } - /// /// Creates the tables used in this TestCase /// diff --git a/src/NHibernate.Test/TestsContext.cs b/src/NHibernate.Test/TestsContext.cs index 4eae9d561e5..e80cca726aa 100644 --- a/src/NHibernate.Test/TestsContext.cs +++ b/src/NHibernate.Test/TestsContext.cs @@ -1,9 +1,8 @@ -#if NETCOREAPP2_0 -using NUnit.Framework; - +using NUnit.Framework; using System.Configuration; -using System.IO; -using log4net.Repository.Hierarchy; +using System.Reflection; +using log4net; +using log4net.Config; using NHibernate.Cfg; namespace NHibernate.Test @@ -11,40 +10,25 @@ namespace NHibernate.Test [SetUpFixture] public class TestsContext { - private static bool ExecutingWithVsTest { get; } = - System.Reflection.Assembly.GetEntryAssembly()?.GetName().Name == "testhost"; + private static readonly Assembly TestAssembly = typeof(TestsContext).Assembly; [OneTimeSetUp] public void RunBeforeAnyTests() { + ConfigureLog4Net(); + //When .NET Core App 2.0 tests run from VS/VSTest the entry assembly is "testhost.dll" //so we need to explicitly load the configuration - if (ExecutingWithVsTest) + if (Assembly.GetEntryAssembly() != null) { - var assemblyPath = Path.Combine(TestContext.CurrentContext.TestDirectory, Path.GetFileName(typeof(TestsContext).Assembly.Location)); - ConfigurationProvider.Current = new SystemConfigurationProvider(ConfigurationManager.OpenExeConfiguration(assemblyPath)); + ConfigurationProvider.Current = new SystemConfigurationProvider(ConfigurationManager.OpenExeConfiguration(TestAssembly.Location)); } - - ConfigureLog4Net(); } private static void ConfigureLog4Net() { - var hierarchy = (Hierarchy)log4net.LogManager.GetRepository(typeof(TestsContext).Assembly); - - var consoleAppender = new log4net.Appender.ConsoleAppender - { - Layout = new log4net.Layout.PatternLayout("%d{ABSOLUTE} %-5p %c{1}:%L - %m%n"), - }; - - ((Logger)hierarchy.GetLogger("NHibernate.Hql.Ast.ANTLR")).Level = log4net.Core.Level.Off; - ((Logger)hierarchy.GetLogger("NHibernate.SQL")).Level = log4net.Core.Level.Off; - ((Logger)hierarchy.GetLogger("NHibernate.AdoNet.AbstractBatcher")).Level = log4net.Core.Level.Off; - ((Logger)hierarchy.GetLogger("NHibernate.Tool.hbm2ddl.SchemaExport")).Level = log4net.Core.Level.Error; - hierarchy.Root.Level = log4net.Core.Level.Warn; - hierarchy.Root.AddAppender(consoleAppender); - hierarchy.Configured = true; + using (var log4NetXml = TestAssembly.GetManifestResourceStream("NHibernate.Test.log4net.xml")) + XmlConfigurator.Configure(LogManager.GetRepository(TestAssembly), log4NetXml); } } } -#endif diff --git a/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs b/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs index 33bfe99a710..3efb4bbf5b5 100644 --- a/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs +++ b/src/NHibernate.Test/TypesTest/TypeFactoryFixture.cs @@ -1,6 +1,5 @@ using System; using log4net; -using log4net.Repository.Hierarchy; using NHibernate.Type; using NUnit.Framework; @@ -12,11 +11,6 @@ namespace NHibernate.Test.TypesTest [TestFixture] public class TypeFactoryFixture { - public TypeFactoryFixture() - { - log4net.Config.XmlConfigurator.Configure(LogManager.GetRepository(typeof(TypeFactoryFixture).Assembly)); - } - private static readonly ILog log = LogManager.GetLogger(typeof(TypeFactoryFixture)); /// diff --git a/src/NHibernate.Test/log4net.xml b/src/NHibernate.Test/log4net.xml new file mode 100644 index 00000000000..01844a5d696 --- /dev/null +++ b/src/NHibernate.Test/log4net.xml @@ -0,0 +1,45 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From b14eea05fe9713dda04e017c2cafd3f24bcb355e Mon Sep 17 00:00:00 2001 From: maca88 Date: Fri, 6 Mar 2020 18:59:33 +0100 Subject: [PATCH 31/43] Add support for MemberInit expression in GroupBySelectClauseRewriter (#2221) --- .../Async/Linq/ByMethod/GroupByTests.cs | 42 +++++++++++++++++++ .../Linq/ByMethod/GroupByTests.cs | 42 +++++++++++++++++++ .../GroupBy/GroupBySelectClauseRewriter.cs | 19 ++++++++- 3 files changed, 101 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs index 7d55ec6b069..a672dd18381 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/GroupByTests.cs @@ -910,6 +910,48 @@ public async Task GroupByOrderByKeySelectToClassAsync() .ToListAsync()); } + [Test] + public async Task SelectArrayIndexBeforeGroupByAsync() + { + var result = db.Orders + .SelectMany(o => o.OrderLines.Select(c => c.Id).DefaultIfEmpty().Select(c => new object[] {c, o})) + .GroupBy(g => g[0], g => (Order) g[1]) + .Select(g => new[] {g.Key, g.Count(), g.Max(x => x.OrderDate)}); + + Assert.True(await (result.AnyAsync())); + } + + [Test] + public async Task SelectMemberInitBeforeGroupByAsync() + { + var result = await (db.Orders + .Select(o => new OrderGroup {OrderId = o.OrderId, OrderDate = o.OrderDate}) + .GroupBy(o => o.OrderId) + .Select(g => new OrderGroup {OrderId = g.Key, OrderDate = g.Max(o => o.OrderDate)}) + .ToListAsync()); + + Assert.True(result.Any()); + } + + [Test] + public async Task SelectNewBeforeGroupByAsync() + { + var result = await (db.Orders + .Select(o => new {o.OrderId, o.OrderDate}) + .GroupBy(o => o.OrderId) + .Select(g => new {OrderId = g.Key, OrderDate = g.Max(o => o.OrderDate)}) + .ToListAsync()); + + Assert.True(result.Any()); + } + + private class OrderGroup + { + public int OrderId { get; set; } + + public DateTime? OrderDate { get; set; } + } + private class GroupInfo { public object Key { get; set; } diff --git a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs index 02612fec2e8..626dc619692 100644 --- a/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/GroupByTests.cs @@ -912,6 +912,48 @@ public void FetchBeforeGroupBy() Assert.True(result.Any()); } + [Test] + public void SelectArrayIndexBeforeGroupBy() + { + var result = db.Orders + .SelectMany(o => o.OrderLines.Select(c => c.Id).DefaultIfEmpty().Select(c => new object[] {c, o})) + .GroupBy(g => g[0], g => (Order) g[1]) + .Select(g => new[] {g.Key, g.Count(), g.Max(x => x.OrderDate)}); + + Assert.True(result.Any()); + } + + [Test] + public void SelectMemberInitBeforeGroupBy() + { + var result = db.Orders + .Select(o => new OrderGroup {OrderId = o.OrderId, OrderDate = o.OrderDate}) + .GroupBy(o => o.OrderId) + .Select(g => new OrderGroup {OrderId = g.Key, OrderDate = g.Max(o => o.OrderDate)}) + .ToList(); + + Assert.True(result.Any()); + } + + [Test] + public void SelectNewBeforeGroupBy() + { + var result = db.Orders + .Select(o => new {o.OrderId, o.OrderDate}) + .GroupBy(o => o.OrderId) + .Select(g => new {OrderId = g.Key, OrderDate = g.Max(o => o.OrderDate)}) + .ToList(); + + Assert.True(result.Any()); + } + + private class OrderGroup + { + public int OrderId { get; set; } + + public DateTime? OrderDate { get; set; } + } + private class GroupInfo { public object Key { get; set; } diff --git a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs index 0e569fa6895..fe15d0bcf76 100644 --- a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using NHibernate.Linq.Expressions; @@ -14,6 +15,13 @@ namespace NHibernate.Linq.GroupBy //This should be renamed. It handles entire querymodels, not just select clauses internal class GroupBySelectClauseRewriter : RelinqExpressionVisitor { + private readonly static HashSet _validElementSelectorTypes = new HashSet + { + ExpressionType.ArrayIndex, + ExpressionType.New, + ExpressionType.MemberInit + }; + public static Expression ReWrite(Expression expression, GroupResultOperator groupBy, QueryModel model) { var visitor = new GroupBySelectClauseRewriter(groupBy, model); @@ -67,8 +75,8 @@ protected override Expression VisitMember(MemberExpression expression) return base.VisitMember(expression); } - if ((elementSelector is NewExpression || elementSelector.NodeType == ExpressionType.Convert) - && elementSelector.Type == expression.Expression.Type) + if (_validElementSelectorTypes.Contains(UnwrapUnary(elementSelector).NodeType) && + elementSelector.Type == expression.Expression.Type) { //TODO: probably we should check this with a visitor return Expression.MakeMemberAccess(elementSelector, expression.Member); @@ -156,5 +164,12 @@ expression.QueryModel.SelectClause.Selector is NhCountExpression countExpression // valid assumption. Should probably be passed a list of aggregating subqueries that we are flattening so that we can check... return ReWrite(expression.QueryModel.SelectClause.Selector, _groupBy, _model); } + + private static Expression UnwrapUnary(Expression expression) + { + return expression is UnaryExpression unaryExpression + ? unaryExpression.Operand + : expression; + } } } From 4a6bf6782065adddfc9c6e786a4fad9d668dc533 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 13 Apr 2020 14:39:11 +0200 Subject: [PATCH 32/43] Add support for OData group by queries (#2135) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Frédéric Delaporte <12201973+fredericDelaporte@users.noreply.github.com> --- src/NHibernate.Test/Async/Linq/ODataTests.cs | 164 ++++++++++++++++++ src/NHibernate.Test/Linq/ODataTests.cs | 152 ++++++++++++++++ src/NHibernate.Test/NHibernate.Test.csproj | 1 + .../GroupBy/GroupBySelectClauseRewriter.cs | 2 +- .../Linq/GroupBy/GroupKeyNominator.cs | 21 +++ .../Linq/GroupResultOperatorExtensions.cs | 29 +++- .../PagingRewriterSelectClauseVisitor.cs | 1 - ...rentIdentifierRemovingExpressionVisitor.cs | 119 +++++++++++++ 8 files changed, 482 insertions(+), 7 deletions(-) create mode 100644 src/NHibernate.Test/Async/Linq/ODataTests.cs create mode 100644 src/NHibernate.Test/Linq/ODataTests.cs create mode 100644 src/NHibernate/Linq/Visitors/TransparentIdentifierRemovingExpressionVisitor.cs diff --git a/src/NHibernate.Test/Async/Linq/ODataTests.cs b/src/NHibernate.Test/Async/Linq/ODataTests.cs new file mode 100644 index 00000000000..b4b593648ba --- /dev/null +++ b/src/NHibernate.Test/Async/Linq/ODataTests.cs @@ -0,0 +1,164 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNet.OData; +using Microsoft.AspNet.OData.Builder; +using Microsoft.AspNet.OData.Extensions; +using Microsoft.AspNet.OData.Query; +using Microsoft.AspNet.OData.Query.Expressions; +using Microsoft.AspNetCore.Http; +using Microsoft.OData.Edm; +using NHibernate.DomainModel.Northwind.Entities; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.Linq +{ + using System.Threading.Tasks; + [TestFixture] + public class ODataTestsAsync : LinqTestCase + { + private IEdmModel _edmModel; + + protected override void OnSetUp() + { + base.OnSetUp(); + + _edmModel = CreatEdmModel(); + } + + [TestCase("$apply=groupby((Customer/CustomerId))", 89)] + [TestCase("$apply=groupby((Customer/CustomerId))&$orderby=Customer/CustomerId", 89)] + [TestCase("$apply=groupby((Customer/CustomerId, ShippingAddress/PostalCode), aggregate(OrderId with average as Average, Employee/EmployeeId with max as Max))", 89)] + [TestCase("$apply=groupby((Customer/CustomerId), aggregate(OrderId with sum as Total))&$skip=2", 87)] + public async Task OrderGroupByAsync(string queryString, int expectedRows) + { + var query = ApplyFilter(session.Query(), queryString); + Assert.That(query, Is.AssignableTo>()); + + var results = await (((IQueryable) query).ToListAsync()); + Assert.That(results, Has.Count.EqualTo(expectedRows)); + } + + private IQueryable ApplyFilter(IQueryable query, string queryString) + { + var context = new ODataQueryContext(CreatEdmModel(), typeof(T), null) { }; + var dataQuerySettings = new ODataQuerySettings {HandleNullPropagation = HandleNullPropagationOption.False}; + var serviceProvider = new ODataServiceProvider( + new Dictionary() + { + {typeof(DefaultQuerySettings), new DefaultQuerySettings()}, + {typeof(ODataOptions), new ODataOptions()}, + {typeof(IEdmModel), _edmModel}, + {typeof(ODataQuerySettings), dataQuerySettings}, + }); + + HttpContext httpContext = new DefaultHttpContext(); + httpContext.ODataFeature().RequestContainer = serviceProvider; + httpContext.RequestServices = serviceProvider; + var request = httpContext.Request; + Uri requestUri = new Uri($"http://localhost/?{queryString}"); + request.Method = HttpMethods.Get; + request.Scheme = requestUri.Scheme; + request.Host = new HostString(requestUri.Host); + request.QueryString = new QueryString(requestUri.Query); + request.Path = new PathString(requestUri.AbsolutePath); + var options = new ODataQueryOptions(context, request); + + return options.ApplyTo(query, dataQuerySettings); + } + + private static IEdmModel CreatEdmModel() + { + var builder = new ODataConventionModelBuilder(); + + var adressModel = builder.ComplexType
(); + adressModel.Property(o => o.City); + adressModel.Property(o => o.Country); + adressModel.Property(o => o.Fax); + adressModel.Property(o => o.PhoneNumber); + adressModel.Property(o => o.PostalCode); + adressModel.Property(o => o.Region); + adressModel.Property(o => o.Street); + + var customerModel = builder.EntitySet(nameof(Customer)); + customerModel.EntityType.HasKey(o => o.CustomerId); + customerModel.EntityType.Property(o => o.CompanyName); + customerModel.EntityType.Property(o => o.ContactTitle); + customerModel.EntityType.ComplexProperty(o => o.Address); + customerModel.EntityType.HasMany(o => o.Orders); + + var orderModel = builder.EntitySet(nameof(Order)); + orderModel.EntityType.HasKey(o => o.OrderId); + orderModel.EntityType.Property(o => o.Freight); + orderModel.EntityType.Property(o => o.OrderDate); + orderModel.EntityType.Property(o => o.RequiredDate); + orderModel.EntityType.Property(o => o.ShippedTo); + orderModel.EntityType.Property(o => o.ShippingDate); + orderModel.EntityType.ComplexProperty(o => o.ShippingAddress); + orderModel.EntityType.HasRequired(o => o.Customer); + orderModel.EntityType.HasOptional(o => o.Employee); + + var employeeModel = builder.EntitySet(nameof(Employee)); + employeeModel.EntityType.HasKey(o => o.EmployeeId); + employeeModel.EntityType.Property(o => o.BirthDate); + employeeModel.EntityType.Property(o => o.Extension); + employeeModel.EntityType.Property(o => o.FirstName); + employeeModel.EntityType.Property(o => o.HireDate); + employeeModel.EntityType.Property(o => o.LastName); + employeeModel.EntityType.Property(o => o.Notes); + employeeModel.EntityType.Property(o => o.Title); + employeeModel.EntityType.HasMany(o => o.Orders); + + return builder.GetEdmModel(); + } + + private class ODataServiceProvider : IServiceProvider + { + private readonly Dictionary _singletonObjects = new Dictionary(); + + public ODataServiceProvider(Dictionary singletonObjects) + { + _singletonObjects = singletonObjects; + } + + public object GetService(System.Type serviceType) + { + if (_singletonObjects.TryGetValue(serviceType, out var service)) + { + return service; + } + + var ctor = serviceType.GetConstructor(new System.Type[0]); + if (ctor != null) + { + return ctor.Invoke(new object[0]); + } + + ctor = serviceType.GetConstructor(new[] { typeof(DefaultQuerySettings) }); + if (ctor != null) + { + return ctor.Invoke(new object[] { GetService(typeof(DefaultQuerySettings)) }); + } + + ctor = serviceType.GetConstructor(new[] { typeof(IServiceProvider) }); + if (ctor != null) + { + return ctor.Invoke(new object[] { this }); + } + + return null; + } + } + } +} diff --git a/src/NHibernate.Test/Linq/ODataTests.cs b/src/NHibernate.Test/Linq/ODataTests.cs new file mode 100644 index 00000000000..fa187e46acb --- /dev/null +++ b/src/NHibernate.Test/Linq/ODataTests.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNet.OData; +using Microsoft.AspNet.OData.Builder; +using Microsoft.AspNet.OData.Extensions; +using Microsoft.AspNet.OData.Query; +using Microsoft.AspNet.OData.Query.Expressions; +using Microsoft.AspNetCore.Http; +using Microsoft.OData.Edm; +using NHibernate.DomainModel.Northwind.Entities; +using NUnit.Framework; + +namespace NHibernate.Test.Linq +{ + [TestFixture] + public class ODataTests : LinqTestCase + { + private IEdmModel _edmModel; + + protected override void OnSetUp() + { + base.OnSetUp(); + + _edmModel = CreatEdmModel(); + } + + [TestCase("$apply=groupby((Customer/CustomerId))", 89)] + [TestCase("$apply=groupby((Customer/CustomerId))&$orderby=Customer/CustomerId", 89)] + [TestCase("$apply=groupby((Customer/CustomerId, ShippingAddress/PostalCode), aggregate(OrderId with average as Average, Employee/EmployeeId with max as Max))", 89)] + [TestCase("$apply=groupby((Customer/CustomerId), aggregate(OrderId with sum as Total))&$skip=2", 87)] + public void OrderGroupBy(string queryString, int expectedRows) + { + var query = ApplyFilter(session.Query(), queryString); + Assert.That(query, Is.AssignableTo>()); + + var results = ((IQueryable) query).ToList(); + Assert.That(results, Has.Count.EqualTo(expectedRows)); + } + + private IQueryable ApplyFilter(IQueryable query, string queryString) + { + var context = new ODataQueryContext(CreatEdmModel(), typeof(T), null) { }; + var dataQuerySettings = new ODataQuerySettings {HandleNullPropagation = HandleNullPropagationOption.False}; + var serviceProvider = new ODataServiceProvider( + new Dictionary() + { + {typeof(DefaultQuerySettings), new DefaultQuerySettings()}, + {typeof(ODataOptions), new ODataOptions()}, + {typeof(IEdmModel), _edmModel}, + {typeof(ODataQuerySettings), dataQuerySettings}, + }); + + HttpContext httpContext = new DefaultHttpContext(); + httpContext.ODataFeature().RequestContainer = serviceProvider; + httpContext.RequestServices = serviceProvider; + var request = httpContext.Request; + Uri requestUri = new Uri($"http://localhost/?{queryString}"); + request.Method = HttpMethods.Get; + request.Scheme = requestUri.Scheme; + request.Host = new HostString(requestUri.Host); + request.QueryString = new QueryString(requestUri.Query); + request.Path = new PathString(requestUri.AbsolutePath); + var options = new ODataQueryOptions(context, request); + + return options.ApplyTo(query, dataQuerySettings); + } + + private static IEdmModel CreatEdmModel() + { + var builder = new ODataConventionModelBuilder(); + + var adressModel = builder.ComplexType
(); + adressModel.Property(o => o.City); + adressModel.Property(o => o.Country); + adressModel.Property(o => o.Fax); + adressModel.Property(o => o.PhoneNumber); + adressModel.Property(o => o.PostalCode); + adressModel.Property(o => o.Region); + adressModel.Property(o => o.Street); + + var customerModel = builder.EntitySet(nameof(Customer)); + customerModel.EntityType.HasKey(o => o.CustomerId); + customerModel.EntityType.Property(o => o.CompanyName); + customerModel.EntityType.Property(o => o.ContactTitle); + customerModel.EntityType.ComplexProperty(o => o.Address); + customerModel.EntityType.HasMany(o => o.Orders); + + var orderModel = builder.EntitySet(nameof(Order)); + orderModel.EntityType.HasKey(o => o.OrderId); + orderModel.EntityType.Property(o => o.Freight); + orderModel.EntityType.Property(o => o.OrderDate); + orderModel.EntityType.Property(o => o.RequiredDate); + orderModel.EntityType.Property(o => o.ShippedTo); + orderModel.EntityType.Property(o => o.ShippingDate); + orderModel.EntityType.ComplexProperty(o => o.ShippingAddress); + orderModel.EntityType.HasRequired(o => o.Customer); + orderModel.EntityType.HasOptional(o => o.Employee); + + var employeeModel = builder.EntitySet(nameof(Employee)); + employeeModel.EntityType.HasKey(o => o.EmployeeId); + employeeModel.EntityType.Property(o => o.BirthDate); + employeeModel.EntityType.Property(o => o.Extension); + employeeModel.EntityType.Property(o => o.FirstName); + employeeModel.EntityType.Property(o => o.HireDate); + employeeModel.EntityType.Property(o => o.LastName); + employeeModel.EntityType.Property(o => o.Notes); + employeeModel.EntityType.Property(o => o.Title); + employeeModel.EntityType.HasMany(o => o.Orders); + + return builder.GetEdmModel(); + } + + private class ODataServiceProvider : IServiceProvider + { + private readonly Dictionary _singletonObjects = new Dictionary(); + + public ODataServiceProvider(Dictionary singletonObjects) + { + _singletonObjects = singletonObjects; + } + + public object GetService(System.Type serviceType) + { + if (_singletonObjects.TryGetValue(serviceType, out var service)) + { + return service; + } + + var ctor = serviceType.GetConstructor(new System.Type[0]); + if (ctor != null) + { + return ctor.Invoke(new object[0]); + } + + ctor = serviceType.GetConstructor(new[] { typeof(DefaultQuerySettings) }); + if (ctor != null) + { + return ctor.Invoke(new object[] { GetService(typeof(DefaultQuerySettings)) }); + } + + ctor = serviceType.GetConstructor(new[] { typeof(IServiceProvider) }); + if (ctor != null) + { + return ctor.Invoke(new object[] { this }); + } + + return null; + } + } + } +} diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 11e964f39a6..8b57857cf85 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -52,6 +52,7 @@ + diff --git a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs index fe15d0bcf76..ed47d2c8363 100644 --- a/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs +++ b/src/NHibernate/Linq/GroupBy/GroupBySelectClauseRewriter.cs @@ -3,12 +3,12 @@ using System.Linq; using System.Linq.Expressions; using NHibernate.Linq.Expressions; +using NHibernate.Linq.Visitors; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.GroupBy { diff --git a/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs b/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs index e5149289afa..75779b6dc18 100644 --- a/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs +++ b/src/NHibernate/Linq/GroupBy/GroupKeyNominator.cs @@ -61,6 +61,27 @@ protected override Expression VisitNew(NewExpression expression) return Expression.New(expression.Constructor, expression.Arguments.Select(VisitInternal), expression.Members); } + protected override Expression VisitMemberInit(MemberInitExpression node) + { + _transformed = true; + return Expression.MemberInit((NewExpression) VisitInternal(node.NewExpression), node.Bindings.Select(VisitMemberBinding)); + } + + protected override MemberAssignment VisitMemberAssignment(MemberAssignment node) + { + return node.Update(VisitInternal(node.Expression)); + } + + protected override MemberListBinding VisitMemberListBinding(MemberListBinding node) + { + return node.Update(node.Initializers.Select(o => o.Update(o.Arguments.Select(VisitInternal)))); + } + + protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding node) + { + return node.Update(node.Bindings.Select(VisitMemberBinding)); + } + protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpression expression) { // If the (sub)expression contains a QuerySourceReference, then the entire expression should be nominated diff --git a/src/NHibernate/Linq/GroupResultOperatorExtensions.cs b/src/NHibernate/Linq/GroupResultOperatorExtensions.cs index 2a510db8685..2c73e050340 100644 --- a/src/NHibernate/Linq/GroupResultOperatorExtensions.cs +++ b/src/NHibernate/Linq/GroupResultOperatorExtensions.cs @@ -19,12 +19,31 @@ private static IEnumerable ExtractKeyExpressions(this Expression exp { // Recursively extract key expressions from nested initializers // --> new object[] { ((object)new object[] { x.A, x.B }), x.C } - // --> x.A, x.B, x.C - if (expr is NewExpression) - return (expr as NewExpression).Arguments.SelectMany(ExtractKeyExpressions); - if (expr is NewArrayExpression) - return (expr as NewArrayExpression).Expressions.SelectMany(ExtractKeyExpressions); + // --> new List { x.A, x.B } + // --> new CustomType(x.A, x.B) { C = x.C } + if (expr is NewExpression newExpression) + return newExpression.Arguments.SelectMany(ExtractKeyExpressions); + if (expr is NewArrayExpression newArrayExpression) + return newArrayExpression.Expressions.SelectMany(ExtractKeyExpressions); + if (expr is MemberInitExpression memberInitExpression) + return memberInitExpression.NewExpression.Arguments.SelectMany(ExtractKeyExpressions) + .Union(memberInitExpression.Bindings.SelectMany(ExtractKeyExpressions)); return new[] { expr }; } + + private static IEnumerable ExtractKeyExpressions(MemberBinding memberBinding) + { + switch (memberBinding) + { + case MemberAssignment memberAssignment: + return memberAssignment.Expression.ExtractKeyExpressions(); + case MemberMemberBinding memberMemberBinding: + return memberMemberBinding.Bindings.SelectMany(ExtractKeyExpressions); + case MemberListBinding memberListBinding: + return memberListBinding.Initializers.SelectMany(o => o.Arguments).SelectMany(ExtractKeyExpressions); + default: + return Enumerable.Empty(); + } + } } } diff --git a/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs b/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs index 5e996b1f311..156e2daf9d5 100644 --- a/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs +++ b/src/NHibernate/Linq/Visitors/PagingRewriterSelectClauseVisitor.cs @@ -2,7 +2,6 @@ using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; using Remotion.Linq.Parsing; -using Remotion.Linq.Parsing.ExpressionVisitors; namespace NHibernate.Linq.Visitors { diff --git a/src/NHibernate/Linq/Visitors/TransparentIdentifierRemovingExpressionVisitor.cs b/src/NHibernate/Linq/Visitors/TransparentIdentifierRemovingExpressionVisitor.cs new file mode 100644 index 00000000000..a2901b4068d --- /dev/null +++ b/src/NHibernate/Linq/Visitors/TransparentIdentifierRemovingExpressionVisitor.cs @@ -0,0 +1,119 @@ +// Copyright (c) rubicon IT GmbH, www.rubicon.eu +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. rubicon licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may not use this +// file except in compliance with the License. You may obtain a copy of the +// License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. +// + +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using Remotion.Linq; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Parsing; +using MemberBinding = Remotion.Linq.Parsing.ExpressionVisitors.MemberBindings.MemberBinding; + +namespace NHibernate.Linq.Visitors +{ + // Copied from Relinq and added a fallback for comparing two member info by DeclaringType and Name + // 6.0 TODO: drop if https://github.com/OData/WebApi/issues/2108 is fixed and add a possible breaking + // change requiring to upgrade OData. (See https://github.com/nhibernate/nhibernate-core/pull/2322#discussion_r401215456 ) + /// + /// Replaces expression patterns of the form new T { x = 1, y = 2 }.x () or + /// new T ( x = 1, y = 2 ).x () to 1 (or 2 if y is accessed instead of x). + /// Expressions are also replaced within subqueries; the is changed by the replacement operations, it is not copied. + /// + internal sealed class TransparentIdentifierRemovingExpressionVisitor : RelinqExpressionVisitor + { + public static Expression ReplaceTransparentIdentifiers(Expression expression) + { + Expression expressionBeforeRemove; + Expression expressionAfterRemove = expression; + + // Run again and again until no replacements have been made. + do + { + expressionBeforeRemove = expressionAfterRemove; + expressionAfterRemove = new TransparentIdentifierRemovingExpressionVisitor().Visit(expressionAfterRemove); + } while (expressionAfterRemove != expressionBeforeRemove); + + return expressionAfterRemove; + } + + private TransparentIdentifierRemovingExpressionVisitor() + { + } + + protected override Expression VisitMember(MemberExpression memberExpression) + { + var memberBindings = GetMemberBindingsCreatedByExpression(memberExpression.Expression); + if (memberBindings == null) + return base.VisitMember(memberExpression); + + var matchingAssignment = memberBindings + .Where(binding => binding.MatchesReadAccess(memberExpression.Member)) + .LastOrDefault(); + + // Added logic: In some cases (e.g OData), the member can be from a different derived class, in such case + // we need to check the member DeclaringType instead of ReflectedType + if (matchingAssignment == null && memberExpression.Expression.NodeType == ExpressionType.MemberInit) + { + matchingAssignment = memberBindings + .Where(binding => AreEqual(binding.BoundMember, memberExpression.Member)) + .LastOrDefault(); + } + + if (matchingAssignment == null) + return base.VisitMember(memberExpression); + else + return matchingAssignment.AssociatedExpression; + } + + protected override Expression VisitSubQuery(SubQueryExpression expression) + { + expression.QueryModel.TransformExpressions(ReplaceTransparentIdentifiers); + return expression; // Note that we modifiy the (mutable) QueryModel, we return an unchanged expression + } + + private IEnumerable GetMemberBindingsCreatedByExpression(Expression expression) + { + var memberInitExpression = expression as MemberInitExpression; + if (memberInitExpression != null) + { + return memberInitExpression.Bindings + .Where(binding => binding is MemberAssignment) + .Select(assignment => MemberBinding.Bind(assignment.Member, ((MemberAssignment) assignment).Expression)); + } + else + { + var newExpression = expression as NewExpression; + if (newExpression != null && newExpression.Members != null) + return GetMemberBindingsForNewExpression(newExpression); + else + return null; + } + } + + private IEnumerable GetMemberBindingsForNewExpression(NewExpression newExpression) + { + for (int i = 0; i < newExpression.Members.Count; ++i) + yield return MemberBinding.Bind(newExpression.Members[i], newExpression.Arguments[i]); + } + + private static bool AreEqual(MemberInfo memberInfo, MemberInfo toComapre) + { + return memberInfo.DeclaringType == toComapre.DeclaringType && memberInfo.Name == toComapre.Name; + } + } +} From ac1877b7a50a5996f11b39cea999fb04c91f90ad Mon Sep 17 00:00:00 2001 From: Roman Artiukhin Date: Mon, 13 Apr 2020 17:39:31 +0300 Subject: [PATCH 33/43] Support basic arithmetic operations (+, -, *, /) in QueryOver (#2156) --- .../Criteria/Lambda/IntegrationFixture.cs | 35 ++++ .../Criteria/Lambda/IntegrationFixture.cs | 35 ++++ .../Criterion/ConstantProjection.cs | 10 +- src/NHibernate/Impl/ExpressionProcessor.cs | 158 ++++++++++++------ 4 files changed, 180 insertions(+), 58 deletions(-) diff --git a/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs b/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs index 017634dc4f5..abd2733e24c 100644 --- a/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs +++ b/src/NHibernate.Test/Async/Criteria/Lambda/IntegrationFixture.cs @@ -510,5 +510,40 @@ public async Task StatelessSessionAsync() Assert.That(statelessPerson2.Id, Is.EqualTo(personId)); } } + + [Test] + public async Task QueryOverArithmeticAsync() + { + using (ISession s = OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + await (s.SaveAsync(new Person() {Name = "test person 1", Age = 20})); + await (s.SaveAsync(new Person() {Name = "test person 2", Age = 50})); + await (t.CommitAsync()); + } + + using (var s = OpenSession()) + { + var persons1 = await (s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == 20).ListAsync()); + var persons2 = await (s.QueryOver().Where(p => (-(-p.Age)) > 20).ListAsync()); + var persons3 = await (s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(21).ListAsync()); + var persons4 = await (s.QueryOver().WhereRestrictionOn(p => -(-p.Age)).IsBetween(19).And(21).ListAsync()); + var persons5 = await (s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(51).ListAsync()); + var persons6 = await (s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == p.Age - p.Age + 20).ListAsync()); +#pragma warning disable CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var persons7 = await (s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == null || p.Age * 2 == 20 * 1).ListAsync()); +#pragma warning restore CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var val1 = await (s.QueryOver().Select(p => p.Age * 2).Where(p => p.Age == 20).SingleOrDefaultAsync()); + + Assert.That(persons1.Count, Is.EqualTo(1)); + Assert.That(persons2.Count, Is.EqualTo(1)); + Assert.That(persons3.Count, Is.EqualTo(1)); + Assert.That(persons4.Count, Is.EqualTo(1)); + Assert.That(persons5.Count, Is.EqualTo(2)); + Assert.That(persons6.Count, Is.EqualTo(1)); + Assert.That(persons7.Count, Is.EqualTo(0)); + Assert.That(val1, Is.EqualTo(40)); + } + } } } diff --git a/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs b/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs index 0b3f46ba784..dd23c4b91b4 100644 --- a/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs +++ b/src/NHibernate.Test/Criteria/Lambda/IntegrationFixture.cs @@ -498,5 +498,40 @@ public void StatelessSession() Assert.That(statelessPerson2.Id, Is.EqualTo(personId)); } } + + [Test] + public void QueryOverArithmetic() + { + using (ISession s = OpenSession()) + using (ITransaction t = s.BeginTransaction()) + { + s.Save(new Person() {Name = "test person 1", Age = 20}); + s.Save(new Person() {Name = "test person 2", Age = 50}); + t.Commit(); + } + + using (var s = OpenSession()) + { + var persons1 = s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == 20).List(); + var persons2 = s.QueryOver().Where(p => (-(-p.Age)) > 20).List(); + var persons3 = s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(21).List(); + var persons4 = s.QueryOver().WhereRestrictionOn(p => -(-p.Age)).IsBetween(19).And(21).List(); + var persons5 = s.QueryOver().WhereRestrictionOn(p => ((p.Age * 2) / 2) + 20 - 20).IsBetween(19).And(51).List(); + var persons6 = s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == p.Age - p.Age + 20).List(); +#pragma warning disable CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var persons7 = s.QueryOver().Where(p => ((p.Age * 2) / 2) + 20 - 20 == null || p.Age * 2 == 20 * 1).List(); +#pragma warning restore CS0472 // The result of the expression is always the same since a value of this type is never equal to 'null' + var val1 = s.QueryOver().Select(p => p.Age * 2).Where(p => p.Age == 20).SingleOrDefault(); + + Assert.That(persons1.Count, Is.EqualTo(1)); + Assert.That(persons2.Count, Is.EqualTo(1)); + Assert.That(persons3.Count, Is.EqualTo(1)); + Assert.That(persons4.Count, Is.EqualTo(1)); + Assert.That(persons5.Count, Is.EqualTo(2)); + Assert.That(persons6.Count, Is.EqualTo(1)); + Assert.That(persons7.Count, Is.EqualTo(0)); + Assert.That(val1, Is.EqualTo(40)); + } + } } } diff --git a/src/NHibernate/Criterion/ConstantProjection.cs b/src/NHibernate/Criterion/ConstantProjection.cs index 3504d633ab6..2d901a51b76 100644 --- a/src/NHibernate/Criterion/ConstantProjection.cs +++ b/src/NHibernate/Criterion/ConstantProjection.cs @@ -13,7 +13,7 @@ namespace NHibernate.Criterion public class ConstantProjection : SimpleProjection { private readonly object value; - private readonly TypedValue typedValue; + public TypedValue TypedValue { get; } public ConstantProjection(object value) : this(value, NHibernateUtil.GuessType(value.GetType())) { @@ -22,7 +22,7 @@ public ConstantProjection(object value) : this(value, NHibernateUtil.GuessType(v public ConstantProjection(object value, IType type) { this.value = value; - typedValue = new TypedValue(type, this.value); + TypedValue = new TypedValue(type, this.value); } public override bool IsAggregate @@ -43,19 +43,19 @@ public override bool IsGrouped public override SqlString ToSqlString(ICriteria criteria, int position, ICriteriaQuery criteriaQuery) { return new SqlString( - criteriaQuery.NewQueryParameter(typedValue).Single(), + criteriaQuery.NewQueryParameter(TypedValue).Single(), " as ", GetColumnAliases(position, criteria, criteriaQuery)[0]); } public override IType[] GetTypes(ICriteria criteria, ICriteriaQuery criteriaQuery) { - return new IType[] { typedValue.Type }; + return new IType[] { TypedValue.Type }; } public override TypedValue[] GetTypedValues(ICriteria criteria, ICriteriaQuery criteriaQuery) { - return new TypedValue[] { typedValue }; + return new TypedValue[] { TypedValue }; } } } diff --git a/src/NHibernate/Impl/ExpressionProcessor.cs b/src/NHibernate/Impl/ExpressionProcessor.cs index 874d1c6e063..945a4402494 100644 --- a/src/NHibernate/Impl/ExpressionProcessor.cs +++ b/src/NHibernate/Impl/ExpressionProcessor.cs @@ -5,6 +5,9 @@ using System.Runtime.CompilerServices; using System.Text.RegularExpressions; using NHibernate.Criterion; +using NHibernate.Dialect.Function; +using NHibernate.Engine; +using NHibernate.Type; using NHibernate.Util; using Expression = System.Linq.Expressions.Expression; @@ -84,16 +87,18 @@ public Order CreateOrder(Func orderStringDelegate, Func /// Retrieve the property name from a supplied PropertyProjection - /// Note: throws if the supplied IProjection is not a PropertyProjection + /// Note: throws if the supplied IProjection is not a IPropertyProjection ///
public string AsProperty() { if (_property != null) return _property; - var propertyProjection = _projection as PropertyProjection; + var propertyProjection = _projection as IPropertyProjection; if (propertyProjection == null) throw new InvalidOperationException("Cannot determine property for " + _projection); return propertyProjection.PropertyName; } + + internal bool IsConstant(out ConstantProjection value) => (value = _projection as ConstantProjection) != null; } private static readonly Dictionary> _simpleExpressionCreators; @@ -101,6 +106,8 @@ public string AsProperty() private static readonly Dictionary>> _subqueryExpressionCreatorTypes; private static readonly Dictionary> _customMethodCallProcessors; private static readonly Dictionary> _customProjectionProcessors; + private static readonly Dictionary _binaryArithmethicTemplates = new Dictionary(); + private static readonly ISQLFunction _unaryNegateTemplate; static ExpressionProcessor() { @@ -195,6 +202,17 @@ static ExpressionProcessor() RegisterCustomProjection(() => Math.Round(default(double), default(int)), ProjectionsExtensions.ProcessRound); RegisterCustomProjection(() => Math.Round(default(decimal), default(int)), ProjectionsExtensions.ProcessRound); RegisterCustomProjection(() => ProjectionsExtensions.AsEntity(default(object)), ProjectionsExtensions.ProcessAsEntity); + + RegisterBinaryArithmeticExpression(ExpressionType.Add, "+"); + RegisterBinaryArithmeticExpression(ExpressionType.Subtract, "-"); + RegisterBinaryArithmeticExpression(ExpressionType.Multiply, "*"); + RegisterBinaryArithmeticExpression(ExpressionType.Divide, "/"); + _unaryNegateTemplate = new VarArgsSQLFunction("(-", string.Empty, ")"); + } + + private static void RegisterBinaryArithmeticExpression(ExpressionType type, string sqlOperand) + { + _binaryArithmethicTemplates[type] = new VarArgsSQLFunction("(", sqlOperand, ")"); } private static ICriterion Eq(ProjectionInfo property, object value) @@ -245,15 +263,13 @@ public static object FindValue(Expression expression) public static ProjectionInfo FindMemberProjection(Expression expression) { if (!IsMemberExpression(expression)) - return ProjectionInfo.ForProjection(Projections.Constant(FindValue(expression))); + return AsArithmeticProjection(expression) + ?? ProjectionInfo.ForProjection(Projections.Constant(FindValue(expression), NHibernateUtil.GuessType(expression.Type))); - var unaryExpression = expression as UnaryExpression; - if (unaryExpression != null) + var unwrapExpression = UnwrapConvertExpression(expression); + if (unwrapExpression != null) { - if (!IsConversion(unaryExpression.NodeType)) - throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); - - return FindMemberProjection(unaryExpression.Operand); + return FindMemberProjection(unwrapExpression); } var methodCallExpression = expression as MethodCallExpression; @@ -266,20 +282,69 @@ public static ProjectionInfo FindMemberProjection(Expression expression) return ProjectionInfo.ForProjection(processor(methodCallExpression)); } } - var memberExpression = expression as MemberExpression; - if (memberExpression != null) + var memberExpression = expression as MemberExpression; + if (memberExpression != null) { - var signature = Signature(memberExpression.Member); + var signature = Signature(memberExpression.Member); Func processor; if (_customProjectionProcessors.TryGetValue(signature, out processor)) { - return ProjectionInfo.ForProjection(processor(memberExpression)); + return ProjectionInfo.ForProjection(processor(memberExpression)); } } return ProjectionInfo.ForProperty(FindMemberExpression(expression)); } + private static Expression UnwrapConvertExpression(Expression expression) + { + if (expression is UnaryExpression unaryExpression) + { + if (!IsConversion(unaryExpression.NodeType)) + { + if (IsSupportedUnaryExpression(unaryExpression)) + return null; + + throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); + } + return unaryExpression.Operand; + } + + return null; + } + + private static bool IsSupportedUnaryExpression(UnaryExpression expression) + { + return expression.NodeType == ExpressionType.Negate; + } + + private static ProjectionInfo AsArithmeticProjection(Expression expression) + { + if (!(expression is BinaryExpression be)) + { + if (expression is UnaryExpression unary && unary.NodeType == ExpressionType.Negate) + { + return ProjectionInfo.ForProjection( + new SqlFunctionProjection(_unaryNegateTemplate, TypeFactory.HeuristicType(unary.Type), FindMemberProjection(unary.Operand).AsProjection())); + } + + var unwrapExpression = UnwrapConvertExpression(expression); + return unwrapExpression != null ? AsArithmeticProjection(unwrapExpression) : null; + } + + if (!_binaryArithmethicTemplates.TryGetValue(be.NodeType, out var template)) + { + return null; + } + + return ProjectionInfo.ForProjection( + new SqlFunctionProjection( + template, + TypeFactory.HeuristicType(be.Type), + FindMemberProjection(be.Left).AsProjection(), + FindMemberProjection(be.Right).AsProjection())); + } + //http://stackoverflow.com/a/2509524/259946 private static readonly Regex GeneratedMemberNameRegex = new Regex(@"^(CS\$)?<\w*>[1-9a-s]__[a-zA-Z]+[0-9]*$", RegexOptions.Compiled | RegexOptions.Singleline); @@ -407,13 +472,10 @@ private static System.Type FindMemberType(Expression expression) return memberExpression.Type; } - var unaryExpression = expression as UnaryExpression; - if (unaryExpression != null) + var unwrapExpression = UnwrapConvertExpression(expression); + if (unwrapExpression != null) { - if (!IsConversion(unaryExpression.NodeType)) - throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); - - return FindMemberType(unaryExpression.Operand); + return FindMemberType(unwrapExpression); } var methodCallExpression = expression as MethodCallExpression; @@ -422,6 +484,9 @@ private static System.Type FindMemberType(Expression expression) return methodCallExpression.Method.ReturnType; } + if (expression is BinaryExpression || expression is UnaryExpression) + return expression.Type; + throw new ArgumentException("Could not determine member type from " + expression, nameof(expression)); } @@ -443,13 +508,10 @@ private static bool IsMemberExpression(Expression expression) return EvaluatesToNull(memberExpression.Expression); } - var unaryExpression = expression as UnaryExpression; - if (unaryExpression != null) + var unwrapExpression = UnwrapConvertExpression(expression); + if (unwrapExpression != null) { - if (!IsConversion(unaryExpression.NodeType)) - throw new ArgumentException("Cannot interpret member from " + expression, nameof(expression)); - - return IsMemberExpression(unaryExpression.Operand); + return IsMemberExpression(unwrapExpression); } var methodCallExpression = expression as MethodCallExpression; @@ -504,21 +566,12 @@ private static object ConvertType(object value, System.Type type) throw new ArgumentException(string.Format("Cannot convert '{0}' to {1}", value, type)); } - private static ICriterion ProcessSimpleExpression(BinaryExpression be) - { - if (be.Left.NodeType == ExpressionType.Call && ((MethodCallExpression)be.Left).Method.Name == "CompareString") - return ProcessVisualBasicStringComparison(be); - - return ProcessSimpleExpression(be.Left, be.Right, be.NodeType); - } - - private static ICriterion ProcessSimpleExpression(Expression left, Expression right, ExpressionType nodeType) + private static ICriterion ProcessSimpleExpression(Expression left, TypedValue rightValue, ExpressionType nodeType) { ProjectionInfo property = FindMemberProjection(left); System.Type propertyType = FindMemberType(left); - object value = FindValue(right); - value = ConvertType(value, propertyType); + var value = ConvertType(rightValue.Value, propertyType); if (value == null) return ProcessSimpleNullExpression(property, nodeType); @@ -530,14 +583,17 @@ private static ICriterion ProcessSimpleExpression(Expression left, Expression ri return simpleExpressionCreator(property, value); } - private static ICriterion ProcessVisualBasicStringComparison(BinaryExpression be) + private static ICriterion ProcessAsVisualBasicStringComparison(Expression left, ExpressionType nodeType) { - var methodCall = (MethodCallExpression)be.Left; + if (left.NodeType != ExpressionType.Call) + { + return null; + } - if (IsMemberExpression(methodCall.Arguments[1])) - return ProcessMemberExpression(methodCall.Arguments[0], methodCall.Arguments[1], be.NodeType); - else - return ProcessSimpleExpression(methodCall.Arguments[0], methodCall.Arguments[1], be.NodeType); + var methodCall = (MethodCallExpression) left; + return methodCall.Method.Name == "CompareString" + ? ProcessMemberExpression(methodCall.Arguments[0], methodCall.Arguments[1], nodeType) + : null; } private static ICriterion ProcessSimpleNullExpression(ProjectionInfo property, ExpressionType expressionType) @@ -552,16 +608,16 @@ private static ICriterion ProcessSimpleNullExpression(ProjectionInfo property, E throw new ArgumentException("Cannot supply null value to operator " + expressionType, nameof(expressionType)); } - private static ICriterion ProcessMemberExpression(BinaryExpression be) - { - return ProcessMemberExpression(be.Left, be.Right, be.NodeType); - } - private static ICriterion ProcessMemberExpression(Expression left, Expression right, ExpressionType nodeType) { - ProjectionInfo leftProperty = FindMemberProjection(left); ProjectionInfo rightProperty = FindMemberProjection(right); + if (rightProperty.IsConstant(out var constProjection)) + { + return ProcessAsVisualBasicStringComparison(left, nodeType) + ?? ProcessSimpleExpression(left, constProjection.TypedValue, nodeType); + } + ProjectionInfo leftProperty = FindMemberProjection(left); Func propertyExpressionCreator; if (!_propertyExpressionCreators.TryGetValue(nodeType, out propertyExpressionCreator)) throw new InvalidOperationException("Unhandled property expression type: " + nodeType); @@ -599,11 +655,7 @@ private static ICriterion ProcessBinaryExpression(BinaryExpression expression) case ExpressionType.GreaterThanOrEqual: case ExpressionType.LessThan: case ExpressionType.LessThanOrEqual: - if (IsMemberExpression(expression.Right)) - return ProcessMemberExpression(expression); - else - return ProcessSimpleExpression(expression); - + return ProcessMemberExpression(expression.Left, expression.Right, expression.NodeType); default: throw new NotImplementedException("Unhandled binary expression: " + expression.NodeType + ", " + expression); } From 03950167c9d0fd5a6577afe8fb14455c8a87e879 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Tue, 14 Apr 2020 19:27:18 +0200 Subject: [PATCH 34/43] Fix SQLite typing (#2346) Co-authored-by: maca88 --- src/NHibernate/Dialect/SQLiteDialect.cs | 61 +++++++++++++++-------- src/NHibernate/Type/TimeAsTimeSpanType.cs | 9 ++-- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/NHibernate/Dialect/SQLiteDialect.cs b/src/NHibernate/Dialect/SQLiteDialect.cs index 864defac9a3..5eb1f51dd87 100644 --- a/src/NHibernate/Dialect/SQLiteDialect.cs +++ b/src/NHibernate/Dialect/SQLiteDialect.cs @@ -38,32 +38,48 @@ public SQLiteDialect() protected virtual void RegisterColumnTypes() { + // SQLite really has only five types, and a very lax typing system, see https://www.sqlite.org/datatype3.html + // Please do not map (again) fancy types that do not actually exist in SQLite, as this is kind of supported by + // SQLite but creates bugs in convert operations. RegisterColumnType(DbType.Binary, "BLOB"); - RegisterColumnType(DbType.Byte, "TINYINT"); - RegisterColumnType(DbType.Int16, "SMALLINT"); - RegisterColumnType(DbType.Int32, "INT"); - RegisterColumnType(DbType.Int64, "BIGINT"); + RegisterColumnType(DbType.Byte, "INTEGER"); + RegisterColumnType(DbType.Int16, "INTEGER"); + RegisterColumnType(DbType.Int32, "INTEGER"); + RegisterColumnType(DbType.Int64, "INTEGER"); RegisterColumnType(DbType.SByte, "INTEGER"); RegisterColumnType(DbType.UInt16, "INTEGER"); RegisterColumnType(DbType.UInt32, "INTEGER"); RegisterColumnType(DbType.UInt64, "INTEGER"); - RegisterColumnType(DbType.Currency, "NUMERIC"); - RegisterColumnType(DbType.Decimal, "NUMERIC"); - RegisterColumnType(DbType.Double, "DOUBLE"); - RegisterColumnType(DbType.Single, "DOUBLE"); - RegisterColumnType(DbType.VarNumeric, "NUMERIC"); + + // NUMERIC and REAL are almost the same, they are binary floating point numbers. There is only a slight difference + // for values without a floating part. They will be represented as integers with numeric, but still as floating + // values with real. The side-effect of this is numeric being able of storing exactly bigger integers than real. + // But it also creates bugs in division, when dividing two numeric happening to be integers, the result is then + // never fractional. So we use "REAL" for all. + RegisterColumnType(DbType.Currency, "REAL"); + RegisterColumnType(DbType.Decimal, "REAL"); + RegisterColumnType(DbType.Double, "REAL"); + RegisterColumnType(DbType.Single, "REAL"); + RegisterColumnType(DbType.VarNumeric, "REAL"); + RegisterColumnType(DbType.AnsiString, "TEXT"); RegisterColumnType(DbType.String, "TEXT"); RegisterColumnType(DbType.AnsiStringFixedLength, "TEXT"); RegisterColumnType(DbType.StringFixedLength, "TEXT"); - RegisterColumnType(DbType.Date, "DATE"); - RegisterColumnType(DbType.DateTime, "DATETIME"); - RegisterColumnType(DbType.Time, "TIME"); - RegisterColumnType(DbType.Boolean, "BOOL"); - // UNIQUEIDENTIFIER is not a SQLite type, but SQLite does not care much, see - // https://www.sqlite.org/datatype3.html - RegisterColumnType(DbType.Guid, "UNIQUEIDENTIFIER"); + // https://www.sqlite.org/datatype3.html#boolean_datatype + RegisterColumnType(DbType.Boolean, "INTEGER"); + + // See https://www.sqlite.org/datatype3.html#date_and_time_datatype, we have three choices for date and time + // The one causing the less issues in case of an explicit cast is text. Beware, System.Data.SQLite has an + // internal use only "DATETIME" type. Using it causes it to directly convert the text stored into SQLite to + // a .Net DateTime, but also causes columns in SQLite to have numeric affinity and convert to destroy the + // value. As said in their chm documentation, this "DATETIME" type is for System.Data.SQLite internal use only. + RegisterColumnType(DbType.Date, "TEXT"); + RegisterColumnType(DbType.DateTime, "TEXT"); + RegisterColumnType(DbType.Time, "TEXT"); + + RegisterColumnType(DbType.Guid, _binaryGuid ? "BLOB" : "TEXT"); } protected virtual void RegisterFunctions() @@ -98,8 +114,6 @@ protected virtual void RegisterFunctions() RegisterFunction("iif", new SQLFunctionTemplate(null, "case when ?1 then ?2 else ?3 end")); - RegisterFunction("cast", new SQLiteCastFunction()); - RegisterFunction("round", new StandardSQLFunction("round")); // SQLite has no built-in support of bitwise xor, but can emulate it. @@ -112,7 +126,7 @@ protected virtual void RegisterFunctions() if (_binaryGuid) RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "substr(hex(?1), 7, 2) || substr(hex(?1), 5, 2) || substr(hex(?1), 3, 2) || substr(hex(?1), 1, 2) || '-' || substr(hex(?1), 11, 2) || substr(hex(?1), 9, 2) || '-' || substr(hex(?1), 15, 2) || substr(hex(?1), 13, 2) || '-' || substr(hex(?1), 17, 4) || '-' || substr(hex(?1), 21) ")); else - RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as char)")); + RegisterFunction("strguid", new SQLFunctionTemplate(NHibernateUtil.String, "cast(?1 as text)")); // SQLite random function yields a long, ranging form MinValue to MaxValue. (-9223372036854775808 to // 9223372036854775807). HQL random requires a float from 0 inclusive to 1 exclusive, so we divide by @@ -131,7 +145,8 @@ public override void Configure(IDictionary settings) ConfigureBinaryGuid(settings); - // Re-register functions depending on settings. + // Re-register functions and types depending on settings. + RegisterColumnTypes(); RegisterFunctions(); } @@ -485,13 +500,15 @@ public override bool SupportsForeignKeyConstraintInAlterTable /// public override int MaxAliasLength => 128; + // Since v5.3 + [Obsolete("This class has no usage in NHibernate anymore and will be removed in a future version. Use or extend CastFunction instead.")] [Serializable] protected class SQLiteCastFunction : CastFunction { protected override bool CastingIsRequired(string sqlType) { - // SQLite doesn't support casting to datetime types. It assumes you want an integer and destroys the date string. - if (StringHelper.ContainsCaseInsensitive(sqlType, "date") || StringHelper.ContainsCaseInsensitive(sqlType, "time")) + if (StringHelper.ContainsCaseInsensitive(sqlType, "date") || + StringHelper.ContainsCaseInsensitive(sqlType, "time")) return false; return true; } diff --git a/src/NHibernate/Type/TimeAsTimeSpanType.cs b/src/NHibernate/Type/TimeAsTimeSpanType.cs index 51fa6745b57..e525ecfa555 100644 --- a/src/NHibernate/Type/TimeAsTimeSpanType.cs +++ b/src/NHibernate/Type/TimeAsTimeSpanType.cs @@ -43,10 +43,13 @@ public override object Get(DbDataReader rs, int index, ISessionImplementor sessi try { var value = rs[index]; - if(value is TimeSpan time) //For those dialects where DbType.Time means TimeSpan. + if (value is TimeSpan time) //For those dialects where DbType.Time means TimeSpan. return time; - - return ((DateTime)value).TimeOfDay; + + // Todo: investigate if this convert should be made culture invariant, here and in other NHibernate types, + // such as AbstractDateTimeType and TimeType, or even in all other places doing such converts in NHibernate. + var dbValue = Convert.ToDateTime(value); + return dbValue.TimeOfDay; } catch (Exception ex) { From a217713aa60351992ae5c255b278ef3e14425468 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Thu, 16 Apr 2020 19:47:06 +0300 Subject: [PATCH 35/43] consider usage of QueryOver --- src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 39 +++++++++++++++++++ .../Criteria/CriteriaQueryTranslator.cs | 2 +- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index f1b82cb1da6..c4efd67e915 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -458,6 +458,45 @@ public void Join_Inheritance() Assert.That(results, Is.EquivalentTo(new[] { visit_1.Id, visit_2.Id, })); } + [Test] + public void Join_Inheritance_QueryOver() + { + // arrange + IEnumerable results; + var person = new PersonBase { Login = "dave", FamilyName = "grohl" }; + var visit_1 = new UserEntityVisit { PersonBase = person }; + var visit_2 = new UserEntityVisit { PersonBase = person }; + + using (ISession arrangeSession = OpenSession()) + using (ITransaction tx = arrangeSession.BeginTransaction()) + { + arrangeSession.Save(person); + arrangeSession.Save(visit_1); + arrangeSession.Save(visit_2); + arrangeSession.Flush(); + + tx.Commit(); + } + + // act + using (var session = OpenSession()) + { + PersonBase f = null; + results = + session.QueryOver() + .JoinAlias( + x => x.PersonBase, + () => f, + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Where(() => f.Deleted == false)) + .List() + .Select(x => x.Id); + } + + // assert + Assert.That(results, Is.EquivalentTo(new[] { visit_1.Id, visit_2.Id, })); + } + #region Test Setup protected override HbmMapping GetMappings() diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 6dafb3fda1d..ecafd801a7a 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -770,7 +770,7 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper // here we can check if the condition belongs to a with clause bool useLastIndex = false; var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; - if (withClause != null && withClause.PropertyName == propertyName) + if (withClause != null && withClause.PropertyName.EndsWith(propertyName)) { useLastIndex = true; } From 1382cc8320abf131603a54b5abed22ead372ff10 Mon Sep 17 00:00:00 2001 From: "g.yakimov" Date: Thu, 23 Apr 2020 10:40:43 +0300 Subject: [PATCH 36/43] Revert "Remove unused useLastIndex parameter" This reverts commit 6abb7448f3b951fa0bb1d4626a2bc1b43737cb4c. Revert "Remove unused useLastIndex parameter" This reverts commit bc99fedb2e3146a30fa3dae6513b5a3f5a291d0e. --- src/NHibernate/Persister/Entity/AbstractEntityPersister.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index 51dc9100eba..32639db4d0e 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -2114,7 +2114,7 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath, bool useL ? Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName) : Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! - return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); + return index == -1 ? 0 : GetSubclassPropertyTableNumber(index, false); } public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath) From c885345b2bbb236d8917f2e178133068751dfd85 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 8 May 2020 22:01:42 +1200 Subject: [PATCH 37/43] Move tests to where they belong --- src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 84 ----------------- .../NHSpecificTest/GH2330/Entity.cs | 10 +++ .../NHSpecificTest/GH2330/FixtureByCode.cs | 90 +++++++++++++++++++ .../{Hql => NHSpecificTest/GH2330}/Node.cs | 2 +- 4 files changed, 101 insertions(+), 85 deletions(-) create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs rename src/NHibernate.Test/{Hql => NHSpecificTest/GH2330}/Node.cs (96%) diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index dd67ce6dd88..37c7779efe3 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -421,82 +421,6 @@ public void CrossJoinAndWhereClause() } } - [Test] - public void Join_Inheritance() - { - // arrange - IEnumerable results; - var person = new PersonBase { Login = "dave", FamilyName = "grohl" }; - var visit_1 = new UserEntityVisit { PersonBase = person }; - var visit_2 = new UserEntityVisit { PersonBase = person }; - - using (ISession arrangeSession = OpenSession()) - using (ITransaction tx = arrangeSession.BeginTransaction()) - { - arrangeSession.Save(person); - arrangeSession.Save(visit_1); - arrangeSession.Save(visit_2); - arrangeSession.Flush(); - - tx.Commit(); - } - - // act - using (var session = OpenSession()) - { - results = session.CreateCriteria() - .CreateCriteria( - $"{nameof(UserEntityVisit.PersonBase)}", - "f", - SqlCommand.JoinType.LeftOuterJoin, - Restrictions.Eq("Deleted", false)) - .List() - .Select(x => x.Id); - } - - // assert - Assert.That(results, Is.EquivalentTo(new[] { visit_1.Id, visit_2.Id, })); - } - - [Test] - public void Join_Inheritance_QueryOver() - { - // arrange - IEnumerable results; - var person = new PersonBase { Login = "dave", FamilyName = "grohl" }; - var visit_1 = new UserEntityVisit { PersonBase = person }; - var visit_2 = new UserEntityVisit { PersonBase = person }; - - using (ISession arrangeSession = OpenSession()) - using (ITransaction tx = arrangeSession.BeginTransaction()) - { - arrangeSession.Save(person); - arrangeSession.Save(visit_1); - arrangeSession.Save(visit_2); - arrangeSession.Flush(); - - tx.Commit(); - } - - // act - using (var session = OpenSession()) - { - PersonBase f = null; - results = - session.QueryOver() - .JoinAlias( - x => x.PersonBase, - () => f, - SqlCommand.JoinType.LeftOuterJoin, - Restrictions.Where(() => f.Deleted == false)) - .List() - .Select(x => x.Id); - } - - // assert - Assert.That(results, Is.EquivalentTo(new[] { visit_1.Id, visit_2.Id, })); - } - #region Test Setup protected override HbmMapping GetMappings() @@ -568,9 +492,6 @@ protected override HbmMapping GetMappings() rc.Id(e => e.Id, m => m.Generator(Generators.GuidComb)); rc.Property(e => e.Name); }); - - Node.AddMapping(mapper); - UserEntityVisit.AddMapping(mapper); mapper.Class( rc => @@ -604,11 +525,6 @@ protected override HbmMapping GetMappings() }); }); - - - Node.AddMapping(mapper); - UserEntityVisit.AddMapping(mapper); - return mapper.CompileMappingForAllExplicitlyAddedEntities(); } diff --git a/src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs new file mode 100644 index 00000000000..0c24058f9ba --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs @@ -0,0 +1,10 @@ +using System; + +namespace NHibernate.Test.NHSpecificTest.GH2330 +{ + class Entity + { + public virtual Guid Id { get; set; } + public virtual string Name { get; set; } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs new file mode 100644 index 00000000000..9223534a5bc --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs @@ -0,0 +1,90 @@ +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Criterion; +using NHibernate.Mapping.ByCode; +using NHibernate.Test.Hql; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2330 +{ + [TestFixture] + public class ByCodeFixture : TestCaseMappingByCode + { + private object _visit1Id; + private object _visit2Id; + + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + + Node.AddMapping(mapper); + UserEntityVisit.AddMapping(mapper); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var arrangeSession = OpenSession()) + using (var tx = arrangeSession.BeginTransaction()) + { + var person = new PersonBase {Login = "dave", FamilyName = "grohl"}; + arrangeSession.Save(person); + _visit1Id = arrangeSession.Save(new UserEntityVisit {PersonBase = person}); + _visit2Id = arrangeSession.Save(new UserEntityVisit {PersonBase = person}); + arrangeSession.Flush(); + + tx.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + + transaction.Commit(); + } + } + + [Test] + public void Join_Inheritance() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + { + var results = session + .CreateCriteria() + .CreateCriteria( + $"{nameof(UserEntityVisit.PersonBase)}", + "f", + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Eq("Deleted", false)) + .List() + .Select(x => x.Id); + Assert.That(results, Is.EquivalentTo(new[] {_visit1Id, _visit2Id,})); + } + } + + [Test] + public void Join_Inheritance_QueryOver() + { + using (var session = OpenSession()) + { + PersonBase f = null; + var results = session.QueryOver() + .JoinAlias( + x => x.PersonBase, + () => f, + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Where(() => f.Deleted == false)) + .List() + .Select(x => x.Id); + // assert + Assert.That(results, Is.EquivalentTo(new[] {_visit1Id, _visit2Id,})); + } + } + } +} diff --git a/src/NHibernate.Test/Hql/Node.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs similarity index 96% rename from src/NHibernate.Test/Hql/Node.cs rename to src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs index e3cb2937002..d0e4cbf139c 100644 --- a/src/NHibernate.Test/Hql/Node.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs @@ -1,7 +1,7 @@ using System; using NHibernate.Mapping.ByCode; -namespace NHibernate.Test.Hql +namespace NHibernate.Test.NHSpecificTest.GH2330 { public abstract class Node { From 2349e5679469439171387a10b1ec749204acf6e2 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 8 May 2020 22:04:07 +1200 Subject: [PATCH 38/43] Revert changes --- src/NHibernate.Test/Hql/EntityJoinHqlTest.cs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs index 37c7779efe3..f01b6303d30 100644 --- a/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs +++ b/src/NHibernate.Test/Hql/EntityJoinHqlTest.cs @@ -1,8 +1,5 @@ -using System.Collections.Generic; -using System.Linq; -using System.Text.RegularExpressions; +using System.Text.RegularExpressions; using NHibernate.Cfg.MappingSchema; -using NHibernate.Criterion; using NHibernate.Mapping.ByCode; using NHibernate.Test.Hql.EntityJoinHqlTestEntities; using NUnit.Framework; @@ -525,6 +522,7 @@ protected override HbmMapping GetMappings() }); }); + return mapper.CompileMappingForAllExplicitlyAddedEntities(); } From 310482dca0fbd6380c0387405344ebe21333d706 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 8 May 2020 22:06:13 +1200 Subject: [PATCH 39/43] Cleanup tests --- src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs | 10 ---------- ...ureByCode.cs => JoinedSubclassWithClauseFixture.cs} | 7 ++++--- 2 files changed, 4 insertions(+), 13 deletions(-) delete mode 100644 src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs rename src/NHibernate.Test/NHSpecificTest/GH2330/{FixtureByCode.cs => JoinedSubclassWithClauseFixture.cs} (95%) diff --git a/src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs deleted file mode 100644 index 0c24058f9ba..00000000000 --- a/src/NHibernate.Test/NHSpecificTest/GH2330/Entity.cs +++ /dev/null @@ -1,10 +0,0 @@ -using System; - -namespace NHibernate.Test.NHSpecificTest.GH2330 -{ - class Entity - { - public virtual Guid Id { get; set; } - public virtual string Name { get; set; } - } -} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs similarity index 95% rename from src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs rename to src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs index 9223534a5bc..69e0f575f63 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2330/FixtureByCode.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs @@ -2,13 +2,12 @@ using NHibernate.Cfg.MappingSchema; using NHibernate.Criterion; using NHibernate.Mapping.ByCode; -using NHibernate.Test.Hql; using NUnit.Framework; namespace NHibernate.Test.NHSpecificTest.GH2330 { [TestFixture] - public class ByCodeFixture : TestCaseMappingByCode + public class JoinedSubclassWithClauseFixture : TestCaseMappingByCode { private object _visit1Id; private object _visit2Id; @@ -64,6 +63,7 @@ public void Join_Inheritance() Restrictions.Eq("Deleted", false)) .List() .Select(x => x.Id); + Assert.That(results, Is.EquivalentTo(new[] {_visit1Id, _visit2Id,})); } } @@ -72,6 +72,7 @@ public void Join_Inheritance() public void Join_Inheritance_QueryOver() { using (var session = OpenSession()) + using (session.BeginTransaction()) { PersonBase f = null; var results = session.QueryOver() @@ -82,7 +83,7 @@ public void Join_Inheritance_QueryOver() Restrictions.Where(() => f.Deleted == false)) .List() .Select(x => x.Id); - // assert + Assert.That(results, Is.EquivalentTo(new[] {_visit1Id, _visit2Id,})); } } From c8386ee2017e87af7852aeb80629bd42cacbfbde Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 8 May 2020 22:08:50 +1200 Subject: [PATCH 40/43] Cleanup tests --- .../GH2330/JoinedSubclassWithClauseFixture.cs | 27 ++++++++- .../NHSpecificTest/GH2330/Node.cs | 55 +------------------ 2 files changed, 28 insertions(+), 54 deletions(-) diff --git a/src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs index 69e0f575f63..7348a5da7a0 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs @@ -16,8 +16,31 @@ protected override HbmMapping GetMappings() { var mapper = new ModelMapper(); - Node.AddMapping(mapper); - UserEntityVisit.AddMapping(mapper); + mapper.Class(ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.Property(x => x.FamilyName); + ca.Table("Node"); + ca.Abstract(true); + }); + + mapper.JoinedSubclass( + ca => + { + ca.Key(x => x.Column("FK_Node_ID")); + ca.Extends(typeof(Node)); + ca.Property(x => x.Deleted); + ca.Property(x => x.Login); + }); + + mapper.Class( + ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.ManyToOne(x => x.PersonBase); + }); return mapper.CompileMappingForAllExplicitlyAddedEntities(); } diff --git a/src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs b/src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs index d0e4cbf139c..bad76be6d32 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2330/Node.cs @@ -1,39 +1,12 @@ using System; -using NHibernate.Mapping.ByCode; namespace NHibernate.Test.NHSpecificTest.GH2330 { public abstract class Node { - private int _id; - public virtual int Id - { - get { return _id; } - set { _id = value; } - } - + public virtual int Id { get; set; } public virtual bool Deleted { get; set; } public virtual string FamilyName { get; set; } - - public static void AddMapping(ModelMapper mapper) - { - mapper.Class(ca => - { - ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); - ca.Property(x => x.Deleted); - ca.Property(x => x.FamilyName); - ca.Table("Node"); - ca.Abstract(true); - }); - - mapper.JoinedSubclass(ca => - { - ca.Key(x => x.Column("FK_Node_ID")); - ca.Extends(typeof(Node)); - ca.Property(x => x.Deleted); - ca.Property(x => x.Login); - }); - } } [Serializable] @@ -46,30 +19,8 @@ public class PersonBase : Node [Serializable] public class UserEntityVisit { - private int _id; - public virtual int Id - { - get { return _id; } - set { _id = value; } - } - + public virtual int Id { get; set; } public virtual bool Deleted { get; set; } - - private PersonBase _PersonBase; - public virtual PersonBase PersonBase - { - get { return _PersonBase; } - set { _PersonBase = value; } - } - - public static void AddMapping(ModelMapper mapper) - { - mapper.Class(ca => - { - ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); - ca.Property(x => x.Deleted); - ca.ManyToOne(x => x.PersonBase); - }); - } + public virtual PersonBase PersonBase { get; set; } } } From b1df2038c1a5e82fcb016d450444175de34fd177 Mon Sep 17 00:00:00 2001 From: Alexander Zaytsev Date: Fri, 8 May 2020 22:12:49 +1200 Subject: [PATCH 41/43] Generate async --- .../GH2330/JoinedSubclassWithClauseFixture.cs | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs new file mode 100644 index 00000000000..9d9444737a6 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2330/JoinedSubclassWithClauseFixture.cs @@ -0,0 +1,125 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NHibernate.Cfg.MappingSchema; +using NHibernate.Criterion; +using NHibernate.Mapping.ByCode; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2330 +{ + using System.Threading.Tasks; + [TestFixture] + public class JoinedSubclassWithClauseFixtureAsync : TestCaseMappingByCode + { + private object _visit1Id; + private object _visit2Id; + + protected override HbmMapping GetMappings() + { + var mapper = new ModelMapper(); + + mapper.Class(ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.Property(x => x.FamilyName); + ca.Table("Node"); + ca.Abstract(true); + }); + + mapper.JoinedSubclass( + ca => + { + ca.Key(x => x.Column("FK_Node_ID")); + ca.Extends(typeof(Node)); + ca.Property(x => x.Deleted); + ca.Property(x => x.Login); + }); + + mapper.Class( + ca => + { + ca.Id(x => x.Id, map => map.Generator(Generators.Identity)); + ca.Property(x => x.Deleted); + ca.ManyToOne(x => x.PersonBase); + }); + + return mapper.CompileMappingForAllExplicitlyAddedEntities(); + } + + protected override void OnSetUp() + { + using (var arrangeSession = OpenSession()) + using (var tx = arrangeSession.BeginTransaction()) + { + var person = new PersonBase {Login = "dave", FamilyName = "grohl"}; + arrangeSession.Save(person); + _visit1Id = arrangeSession.Save(new UserEntityVisit {PersonBase = person}); + _visit2Id = arrangeSession.Save(new UserEntityVisit {PersonBase = person}); + arrangeSession.Flush(); + + tx.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + + transaction.Commit(); + } + } + + [Test] + public async Task Join_InheritanceAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + { + var results = (await (session + .CreateCriteria() + .CreateCriteria( + $"{nameof(UserEntityVisit.PersonBase)}", + "f", + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Eq("Deleted", false)) + .ListAsync())) + .Select(x => x.Id); + + Assert.That(results, Is.EquivalentTo(new[] {_visit1Id, _visit2Id,})); + } + } + + [Test] + public async Task Join_Inheritance_QueryOverAsync() + { + using (var session = OpenSession()) + using (session.BeginTransaction()) + { + PersonBase f = null; + var results = (await (session.QueryOver() + .JoinAlias( + x => x.PersonBase, + () => f, + SqlCommand.JoinType.LeftOuterJoin, + Restrictions.Where(() => f.Deleted == false)) + .ListAsync())) + .Select(x => x.Id); + + Assert.That(results, Is.EquivalentTo(new[] {_visit1Id, _visit2Id,})); + } + } + } +} From 1fadd9bc70305f35bae9a6abb854514a2dfeaa03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Sat, 24 Oct 2020 19:04:14 +0200 Subject: [PATCH 42/43] Undo fix to only keep tests --- .../Ast/ANTLR/Tree/AssignmentSpecification.cs | 2 +- .../Hql/Ast/ANTLR/Tree/ComponentJoin.cs | 8 ++--- .../Hql/Ast/ANTLR/Tree/IntoClause.cs | 4 +-- .../Criteria/CriteriaQueryTranslator.cs | 11 +------ .../Collection/AbstractCollectionPersister.cs | 8 ++--- .../Collection/CollectionPropertyMapping.cs | 6 ++-- .../Collection/ElementPropertyMapping.cs | 6 ++-- .../Entity/AbstractEntityPersister.cs | 31 +++++++++---------- .../Entity/AbstractPropertyMapping.cs | 4 +-- .../Entity/BasicEntityPropertyMapping.cs | 7 +++-- .../Persister/Entity/IPropertyMapping.cs | 7 ++--- src/NHibernate/Persister/Entity/IQueryable.cs | 3 +- .../Entity/JoinedSubclassEntityPersister.cs | 6 ++-- .../Entity/SingleTableEntityPersister.cs | 10 +++--- .../Entity/UnionSubclassEntityPersister.cs | 4 +-- 15 files changed, 52 insertions(+), 65 deletions(-) diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs index 4f450eead3f..351d29318fa 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/AssignmentSpecification.cs @@ -61,7 +61,7 @@ public AssignmentSpecification(IASTNode eq, IQueryable persister) } else { - temp.Add(persister.GetSubclassTableName(persister.GetSubclassPropertyTableNumber(propertyPath, false))); + temp.Add(persister.GetSubclassTableName(persister.GetSubclassPropertyTableNumber(propertyPath))); } _tableNames = new HashSet(temp); diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs index 448bf8ceda2..bfffb9be928 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/ComponentJoin.cs @@ -150,14 +150,14 @@ public bool TryToType(string propertyName, out IType type) return fromElementType.GetBasePropertyMapping().TryToType(GetPropertyPath(propertyName), out type); } - public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) + public string[] ToColumns(string alias, string propertyName) { - return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName), useLastIndex); + return fromElementType.GetBasePropertyMapping().ToColumns(alias, GetPropertyPath(propertyName)); } - public string[] ToColumns(string propertyName, bool useLastIndex = false) + public string[] ToColumns(string propertyName) { - return fromElementType.GetBasePropertyMapping().ToColumns(GetPropertyPath(propertyName), useLastIndex); + return fromElementType.GetBasePropertyMapping().ToColumns(GetPropertyPath(propertyName)); } #endregion diff --git a/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs b/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs index 05e4b9a4248..a59e6d5a8ec 100644 --- a/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs +++ b/src/NHibernate/Hql/Ast/ANTLR/Tree/IntoClause.cs @@ -214,7 +214,7 @@ private bool IsSuperclassProperty(string propertyName) // // we may want to disallow it for discrim-subclass just for // consistency-sake (currently does not work anyway)... - return _persister.GetSubclassPropertyTableNumber(propertyName, false) != 0; + return _persister.GetSubclassPropertyTableNumber(propertyName) != 0; } /// @@ -263,4 +263,4 @@ private static bool AreSqlTypesCompatible(SqlType target, SqlType source) return target.Equals(source); } } -} +} \ No newline at end of file diff --git a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs index 3b066671af3..955c4802dd3 100644 --- a/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs +++ b/src/NHibernate/Loader/Criteria/CriteriaQueryTranslator.cs @@ -13,7 +13,6 @@ using NHibernate.Type; using NHibernate.Util; using IQueryable = NHibernate.Persister.Entity.IQueryable; -using static NHibernate.Impl.CriteriaImpl; namespace NHibernate.Loader.Criteria { @@ -767,15 +766,7 @@ private bool TryGetColumns(ICriteria subcriteria, string path, bool verifyProper return false; } - // here we can check if the condition belongs to a with clause - bool useLastIndex = false; - var withClause = pathCriteria as Subcriteria != null ? ((Subcriteria) pathCriteria).WithClause as SimpleExpression : null; - if (withClause != null && withClause.PropertyName.EndsWith(propertyName)) - { - useLastIndex = true; - } - - columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName, useLastIndex); + columns = propertyMapping.ToColumns(GetSQLAlias(pathCriteria), propertyName); return true; } diff --git a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs index 1f00777843c..033673f9656 100644 --- a/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs +++ b/src/NHibernate/Persister/Collection/AbstractCollectionPersister.cs @@ -1388,7 +1388,7 @@ public bool IsManyToManyFiltered(IDictionary enabledFilters) return IsManyToMany && (manyToManyWhereString != null || manyToManyFilterHelper.IsAffectedBy(enabledFilters)); } - public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) + public string[] ToColumns(string alias, string propertyName) { if ("index".Equals(propertyName)) { @@ -1399,10 +1399,10 @@ public string[] ToColumns(string alias, string propertyName, bool useLastIndex = return StringHelper.Qualify(alias, indexColumnNames); } - return elementPropertyMapping.ToColumns(alias, propertyName, useLastIndex); + return elementPropertyMapping.ToColumns(alias, propertyName); } - public string[] ToColumns(string propertyName, bool useLastIndex = false) + public string[] ToColumns(string propertyName) { if ("index".Equals(propertyName)) { @@ -1414,7 +1414,7 @@ public string[] ToColumns(string propertyName, bool useLastIndex = false) return indexColumnNames; } - return elementPropertyMapping.ToColumns(propertyName, useLastIndex); + return elementPropertyMapping.ToColumns(propertyName); } protected abstract SqlCommandInfo GenerateDeleteString(); diff --git a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs index 569c53eb63d..e9e2f89dc51 100644 --- a/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/CollectionPropertyMapping.cs @@ -57,7 +57,7 @@ public bool TryToType(string propertyName, out IType type) } } - public string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) + public string[] ToColumns(string alias, string propertyName) { string[] cols; switch (propertyName) @@ -107,7 +107,7 @@ public string[] ToColumns(string alias, string propertyName, bool useLastIndex = } } - public string[] ToColumns(string propertyName, bool useLastIndex = false) + public string[] ToColumns(string propertyName) { throw new System.NotSupportedException("References to collections must be define a SQL alias"); } @@ -117,4 +117,4 @@ public IType Type get { return memberPersister.CollectionType; } } } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs index 6ab04642a5f..20e9899ddb6 100644 --- a/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs +++ b/src/NHibernate/Persister/Collection/ElementPropertyMapping.cs @@ -47,7 +47,7 @@ public bool TryToType(string propertyName, out IType outType) } } - public string[] ToColumns(string alias, string propertyName, bool useLastIndex) + public string[] ToColumns(string alias, string propertyName) { if (propertyName == null || "id".Equals(propertyName)) { @@ -59,7 +59,7 @@ public string[] ToColumns(string alias, string propertyName, bool useLastIndex) } } - public string[] ToColumns(string propertyName, bool useLastIndex) + public string[] ToColumns(string propertyName) { throw new System.NotSupportedException("References to collections must be define a SQL alias"); } @@ -71,4 +71,4 @@ public IType Type #endregion } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs index d2a24b8995f..2c71d037c0b 100644 --- a/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/AbstractEntityPersister.cs @@ -1118,9 +1118,9 @@ protected virtual bool IsIdOfTable(int property, int table) return false; } - protected abstract int GetSubclassPropertyTableNumber(int i, bool useLastIndex); + protected abstract int GetSubclassPropertyTableNumber(int i); - internal int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) + internal int GetSubclassPropertyTableNumber(string propertyName, string entityName) { var type = propertyMapping.ToType(propertyName); if (type.IsAssociationType && ((IAssociationType) type).UseLHSPrimaryKey) @@ -1271,7 +1271,7 @@ protected internal virtual SqlString GenerateLazySelectString() // use the subclass closure int propertyNumber = GetSubclassPropertyIndex(lazyPropertyNames[i]); - int tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); + int tableNumber = GetSubclassPropertyTableNumber(propertyNumber); tableNumbers.Add(tableNumber); int[] colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -1326,7 +1326,7 @@ protected virtual IDictionary GenerateLazySelectStringsByFetc // use the subclass closure var propertyNumber = GetSubclassPropertyIndex(lazyPropertyDescriptor.Name); - var tableNumber = GetSubclassPropertyTableNumber(propertyNumber, false); + var tableNumber = GetSubclassPropertyTableNumber(propertyNumber); tableNumbers.Add(tableNumber); var colNumbers = subclassPropertyColumnNumberClosure[propertyNumber]; @@ -2055,12 +2055,12 @@ public virtual string GetRootTableAlias(string drivingAlias) return drivingAlias; } - public virtual string[] ToColumns(string alias, string propertyName, bool useLastIndex = false) + public virtual string[] ToColumns(string alias, string propertyName) { - return propertyMapping.ToColumns(alias, propertyName, useLastIndex); + return propertyMapping.ToColumns(alias, propertyName); } - public string[] ToColumns(string propertyName, bool useLastIndex = false) + public string[] ToColumns(string propertyName) { return propertyMapping.GetColumnNames(propertyName); } @@ -2088,7 +2088,7 @@ public string[] GetPropertyColumnNames(string propertyName) /// SingleTableEntityPersister defines an overloaded form /// which takes the entity name. /// - public virtual int GetSubclassPropertyTableNumber(string propertyPath, bool useLastIndex) + public virtual int GetSubclassPropertyTableNumber(string propertyPath) { string rootPropertyName = StringHelper.Root(propertyPath); IType type = propertyMapping.ToType(rootPropertyName); @@ -2115,16 +2115,13 @@ public virtual int GetSubclassPropertyTableNumber(string propertyPath, bool useL return getSubclassColumnTableNumberClosure()[idx]; } }*/ - int index = useLastIndex - ? Array.LastIndexOf(SubclassPropertyNameClosure, rootPropertyName) - : Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! - - return index == -1 ? 0 : GetSubclassPropertyTableNumber(index, false); + int index = Array.IndexOf(SubclassPropertyNameClosure, rootPropertyName); //TODO: optimize this better! + return index == -1 ? 0 : GetSubclassPropertyTableNumber(index); } public virtual Declarer GetSubclassPropertyDeclarer(string propertyPath) { - int tableIndex = GetSubclassPropertyTableNumber(propertyPath, false); + int tableIndex = GetSubclassPropertyTableNumber(propertyPath); if (tableIndex == 0) { return Declarer.Class; @@ -2172,7 +2169,7 @@ private string GetSubclassAliasedColumn(string rootAlias, int tableNumber, strin public string[] ToColumns(string name, int i) { - string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i, false)); + string alias = GenerateTableAlias(name, GetSubclassPropertyTableNumber(i)); string[] cols = GetSubclassPropertyColumnNames(i); string[] templates = SubclassPropertyFormulaTemplateClosure[i]; string[] result = new string[cols.Length]; @@ -2406,7 +2403,7 @@ private EntityLoader GetAppropriateUniqueKeyLoader(string propertyName, IDiction return uniqueKeyLoaders[propertyName]; } - return CreateUniqueKeyLoader(propertyMapping.ToType(propertyName), propertyMapping.ToColumns(propertyName, false), enabledFilters); + return CreateUniqueKeyLoader(propertyMapping.ToType(propertyName), propertyMapping.ToColumns(propertyName), enabledFilters); } public int GetPropertyIndex(string propertyName) @@ -3690,7 +3687,7 @@ private IDictionary GetColumnsToTableAliasMap(string rootAlias) if (cols != null && cols.Length > 0) { - PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i, false)); + PropertyKey key = new PropertyKey(cols[0], GetSubclassPropertyTableNumber(i)); propDictionary[key] = property; } } diff --git a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs index 40f9550802e..c027568bf18 100644 --- a/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/AbstractPropertyMapping.cs @@ -44,7 +44,7 @@ public bool TryToType(string propertyName, out IType type) return typesByPropertyPath.TryGetValue(propertyName, out type); } - public virtual string[] ToColumns(string alias, string propertyName, bool useLastIndex) + public virtual string[] ToColumns(string alias, string propertyName) { //TODO: *two* hashmap lookups here is one too many... string[] columns = GetColumns(propertyName); @@ -71,7 +71,7 @@ private string[] GetColumns(string propertyName) return columns; } - public virtual string[] ToColumns(string propertyName, bool useLastIndex) + public virtual string[] ToColumns(string propertyName) { string[] columns = GetColumns(propertyName); diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index ff0e71aefc0..52701f71697 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -26,10 +26,11 @@ public override IType Type get { return persister.Type; } } - public override string[] ToColumns(string alias, string propertyName, bool useLastIndex) + public override string[] ToColumns(string alias, string propertyName) { - var tableAlias = persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName, useLastIndex)); - return base.ToColumns(tableAlias, propertyName, useLastIndex); + return + base.ToColumns(persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName)), + propertyName); } } } diff --git a/src/NHibernate/Persister/Entity/IPropertyMapping.cs b/src/NHibernate/Persister/Entity/IPropertyMapping.cs index fc1dc5bf495..dbe08dd9139 100644 --- a/src/NHibernate/Persister/Entity/IPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/IPropertyMapping.cs @@ -34,11 +34,10 @@ public interface IPropertyMapping /// /// /// - /// /// - string[] ToColumns(string alias, string propertyName, bool useLastIndex = false); + string[] ToColumns(string alias, string propertyName); /// Given a property path, return the corresponding column name(s). - string[] ToColumns(string propertyName, bool useLastIndex = false); + string[] ToColumns(string propertyName); } -} +} \ No newline at end of file diff --git a/src/NHibernate/Persister/Entity/IQueryable.cs b/src/NHibernate/Persister/Entity/IQueryable.cs index 5ef7985bd89..f2de9dd6e52 100644 --- a/src/NHibernate/Persister/Entity/IQueryable.cs +++ b/src/NHibernate/Persister/Entity/IQueryable.cs @@ -113,14 +113,13 @@ public interface IQueryable : ILoadable, IPropertyMapping, IJoinable /// to which this property is mapped. ///
/// The name of the property. - /// The name of the property. /// The number of the table to which the property is mapped. /// /// Note that this is not relative to the results from {@link #getConstraintOrderedTableNameClosure()}. /// It is relative to the subclass table name closure maintained internal to the persister (yick!). /// It is also relative to the indexing used to resolve {@link #getSubclassTableName}... /// - int GetSubclassPropertyTableNumber(string propertyPath, bool useLastIndex); + int GetSubclassPropertyTableNumber(string propertyPath); /// Determine whether the given property is declared by our /// mapped class, our super class, or one of our subclasses... diff --git a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs index 0c3ee9130f7..00b03e88157 100644 --- a/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/JoinedSubclassEntityPersister.cs @@ -525,7 +525,7 @@ public override string GenerateFilterConditionAlias(string rootAlias) return GenerateTableAlias(rootAlias, tableSpan - 1); } - public override string[] ToColumns(string alias, string propertyName, bool useLastIndex) + public override string[] ToColumns(string alias, string propertyName) { if (EntityClass.Equals(propertyName)) { @@ -541,11 +541,11 @@ public override string[] ToColumns(string alias, string propertyName, bool useLa } else { - return base.ToColumns(alias, propertyName, useLastIndex); + return base.ToColumns(alias, propertyName); } } - protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) + protected override int GetSubclassPropertyTableNumber(int i) { return subclassPropertyTableNumberClosure[i]; } diff --git a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs index 4dbb004fc8e..9aa8a71a3e4 100644 --- a/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/SingleTableEntityPersister.cs @@ -675,7 +675,7 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) + protected override int GetSubclassPropertyTableNumber(int i) { return subclassPropertyTableNumberClosure[i]; } @@ -696,12 +696,12 @@ protected override void AddDiscriminatorToInsert(SqlInsertBuilder insert) protected override bool IsSubclassPropertyDeferred(string propertyName, string entityName) { return - hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName, false)); + hasSequentialSelects && IsSubclassTableSequentialSelect(base.GetSubclassPropertyTableNumber(propertyName, entityName)); } protected override bool IsPropertyDeferred(int propertyIndex) { - return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex, false)]; + return _hasSequentialSelect && subclassTableSequentialSelect[GetSubclassPropertyTableNumber(propertyIndex)]; } //Since v5.3 @@ -713,9 +713,9 @@ public override bool HasSequentialSelect //Since v5.3 [Obsolete("This method has no more usage in NHibernate and will be removed in a future version.")] - public new int GetSubclassPropertyTableNumber(string propertyName, string entityName, bool useLastIndex = false) + public new int GetSubclassPropertyTableNumber(string propertyName, string entityName) { - return base.GetSubclassPropertyTableNumber(propertyName, entityName, useLastIndex); + return base.GetSubclassPropertyTableNumber(propertyName, entityName); } //Since v5.3 diff --git a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs index 0b1dd021b61..c8d2c834cac 100644 --- a/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs +++ b/src/NHibernate/Persister/Entity/UnionSubclassEntityPersister.cs @@ -289,12 +289,12 @@ protected override void AddDiscriminatorToSelect(SelectFragment select, string n select.AddColumn(name, DiscriminatorColumnName, DiscriminatorAlias); } - protected override int GetSubclassPropertyTableNumber(int i, bool useLastIndex) + protected override int GetSubclassPropertyTableNumber(int i) { return 0; } - public override int GetSubclassPropertyTableNumber(string propertyName, bool useLastIndex) + public override int GetSubclassPropertyTableNumber(string propertyName) { return 0; } From 2591e3bc14c6e94da638b9a409ea2d6a8f16f03f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Delaporte?= <12201973+fredericDelaporte@users.noreply.github.com> Date: Sat, 24 Oct 2020 19:06:27 +0200 Subject: [PATCH 43/43] Undo remaining whitespace change --- src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs index 52701f71697..02f625bd550 100644 --- a/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs +++ b/src/NHibernate/Persister/Entity/BasicEntityPropertyMapping.cs @@ -30,7 +30,7 @@ public override string[] ToColumns(string alias, string propertyName) { return base.ToColumns(persister.GenerateTableAlias(alias, persister.GetSubclassPropertyTableNumber(propertyName)), - propertyName); + propertyName); } } }