diff --git a/src/NHibernate.Test/Async/Extralazy/ExtraLazyFixture.cs b/src/NHibernate.Test/Async/Extralazy/ExtraLazyFixture.cs index f2fb6f699ab..4543d6e8a6f 100644 --- a/src/NHibernate.Test/Async/Extralazy/ExtraLazyFixture.cs +++ b/src/NHibernate.Test/Async/Extralazy/ExtraLazyFixture.cs @@ -1326,6 +1326,52 @@ public async Task SetAddAsync(bool initialize) } } + [Test] + public async Task SetAddWithOverrideEqualsAsync() + { + User gavin; + User robert; + User tom; + + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + gavin = new User("gavin", "secret"); + robert = new User("robert", "secret"); + tom = new User("tom", "secret"); + await (s.PersistAsync(gavin)); + await (s.PersistAsync(robert)); + await (s.PersistAsync(tom)); + + gavin.Followers.Add(new UserFollower(gavin, robert)); + gavin.Followers.Add(new UserFollower(gavin, tom)); + robert.Followers.Add(new UserFollower(robert, tom)); + + Assert.That(gavin.Followers.Count, Is.EqualTo(2), "Gavin's documents count after adding 2"); + Assert.That(robert.Followers.Count, Is.EqualTo(1), "Robert's followers count after adding one"); + + await (t.CommitAsync()); + } + + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + gavin = await (s.GetAsync("gavin")); + robert = await (s.GetAsync("robert")); + tom = await (s.GetAsync("tom")); + + // Re-add + Assert.That(gavin.Followers.Add(new UserFollower(gavin, robert)), Is.False, "Re-adding element"); + Assert.That(NHibernateUtil.IsInitialized(gavin.Followers), Is.True, "Documents initialization status after re-adding"); + Assert.That(gavin.Followers, Has.Count.EqualTo(2), "Gavin's followers count after re-adding"); + + // Add new + Assert.That(robert.Followers.Add(new UserFollower(robert, gavin)), Is.True, "Adding element"); + Assert.That(NHibernateUtil.IsInitialized(gavin.Followers), Is.True, "Documents initialization status after adding"); + Assert.That(gavin.Followers, Has.Count.EqualTo(2), "Robert's followers count after re-adding"); + } + } + [TestCase(false, false)] [TestCase(false, true)] [TestCase(true, false)] diff --git a/src/NHibernate.Test/Extralazy/ExtraLazyFixture.cs b/src/NHibernate.Test/Extralazy/ExtraLazyFixture.cs index bc2e9e6019c..8b11af550fc 100644 --- a/src/NHibernate.Test/Extralazy/ExtraLazyFixture.cs +++ b/src/NHibernate.Test/Extralazy/ExtraLazyFixture.cs @@ -1315,6 +1315,52 @@ public void SetAdd(bool initialize) } } + [Test] + public void SetAddWithOverrideEquals() + { + User gavin; + User robert; + User tom; + + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + gavin = new User("gavin", "secret"); + robert = new User("robert", "secret"); + tom = new User("tom", "secret"); + s.Persist(gavin); + s.Persist(robert); + s.Persist(tom); + + gavin.Followers.Add(new UserFollower(gavin, robert)); + gavin.Followers.Add(new UserFollower(gavin, tom)); + robert.Followers.Add(new UserFollower(robert, tom)); + + Assert.That(gavin.Followers.Count, Is.EqualTo(2), "Gavin's documents count after adding 2"); + Assert.That(robert.Followers.Count, Is.EqualTo(1), "Robert's followers count after adding one"); + + t.Commit(); + } + + using (var s = OpenSession()) + using (var t = s.BeginTransaction()) + { + gavin = s.Get("gavin"); + robert = s.Get("robert"); + tom = s.Get("tom"); + + // Re-add + Assert.That(gavin.Followers.Add(new UserFollower(gavin, robert)), Is.False, "Re-adding element"); + Assert.That(NHibernateUtil.IsInitialized(gavin.Followers), Is.True, "Documents initialization status after re-adding"); + Assert.That(gavin.Followers, Has.Count.EqualTo(2), "Gavin's followers count after re-adding"); + + // Add new + Assert.That(robert.Followers.Add(new UserFollower(robert, gavin)), Is.True, "Adding element"); + Assert.That(NHibernateUtil.IsInitialized(gavin.Followers), Is.True, "Documents initialization status after adding"); + Assert.That(gavin.Followers, Has.Count.EqualTo(2), "Robert's followers count after re-adding"); + } + } + [TestCase(false, false)] [TestCase(false, true)] [TestCase(true, false)] diff --git a/src/NHibernate.Test/Extralazy/User.cs b/src/NHibernate.Test/Extralazy/User.cs index e25b92adea5..b4f174e3243 100644 --- a/src/NHibernate.Test/Extralazy/User.cs +++ b/src/NHibernate.Test/Extralazy/User.cs @@ -50,6 +50,8 @@ public virtual ISet Photos public virtual ISet Permissions { get; set; } = new HashSet(); + public virtual ISet Followers { get; set; } = new HashSet(); + public virtual IList Companies { get; set; } = new List(); public virtual IList CreditCards { get; set; } = new List(); diff --git a/src/NHibernate.Test/Extralazy/UserFollower.cs b/src/NHibernate.Test/Extralazy/UserFollower.cs new file mode 100644 index 00000000000..8c3ba8f48b4 --- /dev/null +++ b/src/NHibernate.Test/Extralazy/UserFollower.cs @@ -0,0 +1,46 @@ +using System; + +namespace NHibernate.Test.Extralazy +{ + public class UserFollower : IEquatable + { + public UserFollower(User user, User follower) + { + User = user; + Follower = follower; + } + + protected UserFollower() + { + } + + public virtual int Id { get; set; } + + public virtual User User { get; set; } + + public virtual User Follower { get; set; } + + public override bool Equals(object obj) + { + if (obj == null) return false; + if (ReferenceEquals(this, obj)) return true; + if (obj.GetType() != GetType()) return false; + return Equals((UserFollower) obj); + } + + public virtual bool Equals(UserFollower other) + { + if (other == null) return false; + if (ReferenceEquals(this, other)) return true; + return Equals(User.Name, other.User.Name) && Equals(Follower.Name, other.Follower.Name); + } + + public override int GetHashCode() + { + unchecked + { + return (User.Name.GetHashCode() * 397) ^ Follower.Name.GetHashCode(); + } + } + } +} diff --git a/src/NHibernate.Test/Extralazy/UserGroup.hbm.xml b/src/NHibernate.Test/Extralazy/UserGroup.hbm.xml index 9641c5800c6..2dd3fb6433d 100644 --- a/src/NHibernate.Test/Extralazy/UserGroup.hbm.xml +++ b/src/NHibernate.Test/Extralazy/UserGroup.hbm.xml @@ -30,6 +30,10 @@ + + + + @@ -77,6 +81,13 @@ + + + + + + + diff --git a/src/NHibernate/Async/Collection/AbstractPersistentCollection.cs b/src/NHibernate/Async/Collection/AbstractPersistentCollection.cs index aba0ae46271..95897e09f8e 100644 --- a/src/NHibernate/Async/Collection/AbstractPersistentCollection.cs +++ b/src/NHibernate/Async/Collection/AbstractPersistentCollection.cs @@ -74,13 +74,14 @@ public abstract partial class AbstractPersistentCollection : IPersistentCollecti return null; } - internal async Task IsTransientAsync(object element, CancellationToken cancellationToken) + internal async Task CanSkipElementExistenceCheckAsync(object element, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); var queryableCollection = (IQueryableCollection) Session.Factory.GetCollectionPersister(Role); return queryableCollection != null && queryableCollection.ElementType.IsEntityType && + !queryableCollection.ElementPersister.EntityMetamodel.OverridesEquals && !element.IsProxy() && !Session.PersistenceContext.IsEntryFor(element) && await (ForeignKeys.IsTransientFastAsync(queryableCollection.ElementPersister.EntityName, element, Session, cancellationToken)).ConfigureAwait(false) == true; diff --git a/src/NHibernate/Collection/AbstractPersistentCollection.cs b/src/NHibernate/Collection/AbstractPersistentCollection.cs index 13e0957914a..7626b0e2509 100644 --- a/src/NHibernate/Collection/AbstractPersistentCollection.cs +++ b/src/NHibernate/Collection/AbstractPersistentCollection.cs @@ -775,12 +775,13 @@ private AbstractQueueOperationTracker TryFlushAndGetQueueOperationTracker(string return queueOperationTracker; } - internal bool IsTransient(object element) + internal bool CanSkipElementExistenceCheck(object element) { var queryableCollection = (IQueryableCollection) Session.Factory.GetCollectionPersister(Role); return queryableCollection != null && queryableCollection.ElementType.IsEntityType && + !queryableCollection.ElementPersister.EntityMetamodel.OverridesEquals && !element.IsProxy() && !Session.PersistenceContext.IsEntryFor(element) && ForeignKeys.IsTransientFast(queryableCollection.ElementPersister.EntityName, element, Session) == true; diff --git a/src/NHibernate/Collection/Generic/PersistentGenericSet.cs b/src/NHibernate/Collection/Generic/PersistentGenericSet.cs index c4df138fe5a..f0318b7d382 100644 --- a/src/NHibernate/Collection/Generic/PersistentGenericSet.cs +++ b/src/NHibernate/Collection/Generic/PersistentGenericSet.cs @@ -314,8 +314,8 @@ public bool Contains(T item) public bool Add(T o) { // Skip checking the element existence in the database if we know that the element - // is transient and the operation queue is enabled - if (WasInitialized || !IsOperationQueueEnabled || !IsTransient(o)) + // is transient, the mapped class does not override Equals method and the operation queue is enabled + if (WasInitialized || !IsOperationQueueEnabled || !CanSkipElementExistenceCheck(o)) { var exists = IsOperationQueueEnabled ? ReadElementExistence(o, out _) : null; if (!exists.HasValue) diff --git a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs index 3b793d348a8..2a02711b2bf 100644 --- a/src/NHibernate/Tuple/Entity/EntityMetamodel.cs +++ b/src/NHibernate/Tuple/Entity/EntityMetamodel.cs @@ -90,6 +90,7 @@ public EntityMetamodel(PersistentClass persistentClass, ISessionFactoryImplement type = persistentClass.MappedClass; rootType = persistentClass.RootClazz.MappedClass; rootTypeAssemblyQualifiedName = rootType == null ? null : rootType.AssemblyQualifiedName; + OverridesEquals = type != null && ReflectHelper.OverridesEquals(type); // type will be null for dynamic entities identifierProperty = PropertyFactory.BuildIdentifierProperty(persistentClass, sessionFactory.GetIdentifierGenerator(rootName)); @@ -549,6 +550,8 @@ public StandardProperty[] Properties get { return properties; } } + internal bool OverridesEquals { get; private set; } + public int GetPropertyIndex(string propertyName) { int? index = GetPropertyIndexOrNull(propertyName);