diff --git a/src/AsyncGenerator.yml b/src/AsyncGenerator.yml index 99c765a110c..5c7754819fb 100644 --- a/src/AsyncGenerator.yml +++ b/src/AsyncGenerator.yml @@ -160,9 +160,6 @@ transformation: configureAwaitArgument: false localFunctions: true - asyncLock: - type: NHibernate.Util.AsyncLock - methodName: LockAsync documentationComments: addOrReplaceMethodSummary: - name: Commit diff --git a/src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs b/src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs new file mode 100644 index 00000000000..b22f20a3cd0 --- /dev/null +++ b/src/NHibernate.Test/Async/UtilityTest/AsyncReaderWriterLockFixture.cs @@ -0,0 +1,215 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.UtilityTest +{ + public partial class AsyncReaderWriterLockFixture + { + + [Test, Explicit] + public async Task TestConcurrentReadWriteAsync() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var writeReleaser = await (l.WriteLockAsync()); + Assert.That(l.Writing, Is.True); + + var secondWriteSemaphore = new SemaphoreSlim(0); + var secondWriteReleaser = default(AsyncReaderWriterLock.Releaser); + var secondWriteThread = new Thread( + () => + { + secondWriteSemaphore.Wait(); + secondWriteReleaser = l.WriteLock(); + }); + secondWriteThread.Priority = ThreadPriority.Highest; + secondWriteThread.Start(); + await (AssertEqualValueAsync(() => secondWriteThread.ThreadState == ThreadState.WaitSleepJoin, true)); + + var secondReadThreads = new Thread[20]; + var secondReadReleasers = new AsyncReaderWriterLock.Releaser[secondReadThreads.Length]; + var secondReadSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < secondReadReleasers.Length; j++) + { + var index = j; + var thread = new Thread( + () => + { + secondReadSemaphore.Wait(); + secondReadReleasers[index] = l.ReadLock(); + }); + thread.Priority = ThreadPriority.Highest; + secondReadThreads[j] = thread; + thread.Start(); + } + + await (AssertEqualValueAsync(() => secondReadThreads.All(o => o.ThreadState == ThreadState.WaitSleepJoin), true)); + + var firstReadReleaserTasks = new Task[30]; + var firstReadStopSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < firstReadReleaserTasks.Length; j++) + { + firstReadReleaserTasks[j] = Task.Run(async () => + { + var releaser = await (l.ReadLockAsync()); + await (firstReadStopSemaphore.WaitAsync()); + releaser.Dispose(); + }); + } + + await (AssertEqualValueAsync(() => l.ReadersWaiting, firstReadReleaserTasks.Length, waitDelay: 60000)); + + writeReleaser.Dispose(); + secondWriteSemaphore.Release(); + secondReadSemaphore.Release(secondReadThreads.Length); + await (Task.Delay(1000)); + firstReadStopSemaphore.Release(firstReadReleaserTasks.Length); + + await (AssertEqualValueAsync(() => firstReadReleaserTasks.All(o => o.IsCompleted), true)); + Assert.That(l.ReadersWaiting, Is.EqualTo(secondReadThreads.Length)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + await (AssertEqualValueAsync(() => secondWriteThread.IsAlive, false)); + await (AssertEqualValueAsync(() => secondReadThreads.All(o => o.IsAlive), true)); + + secondWriteReleaser.Dispose(); + await (AssertEqualValueAsync(() => secondReadThreads.All(o => !o.IsAlive), true)); + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(secondReadThreads.Length)); + + foreach (var secondReadReleaser in secondReadReleasers) + { + secondReadReleaser.Dispose(); + } + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + } + } + + [Test] + public async Task TestInvaildExitReadLockUsageAsync() + { + var l = new AsyncReaderWriterLock(); + var readReleaser = await (l.ReadLockAsync()); + var readReleaser2 = await (l.ReadLockAsync()); + + readReleaser.Dispose(); + readReleaser2.Dispose(); + Assert.Throws(() => readReleaser.Dispose()); + Assert.Throws(() => readReleaser2.Dispose()); + } + + [Test] + public void TestOperationAfterDisposeAsync() + { + var l = new AsyncReaderWriterLock(); + l.Dispose(); + + Assert.ThrowsAsync(() => l.ReadLockAsync()); + Assert.ThrowsAsync(() => l.WriteLockAsync()); + } + + [Test] + public async Task TestInvaildExitWriteLockUsageAsync() + { + var l = new AsyncReaderWriterLock(); + var writeReleaser = await (l.WriteLockAsync()); + + writeReleaser.Dispose(); + Assert.Throws(() => writeReleaser.Dispose()); + } + + private static async Task LockAsync( + AsyncReaderWriterLock readWriteLock, + Random random, + LockStatistics lockStatistics, + System.Action checkLockAction, + Func canContinue, CancellationToken cancellationToken = default(CancellationToken)) + { + while (canContinue()) + { + var isRead = random.Next(100) < 80; + var releaser = isRead ? await (readWriteLock.ReadLockAsync()) : await (readWriteLock.WriteLockAsync()); + lock (readWriteLock) + { + if (isRead) + { + lockStatistics.ReadLockCount++; + } + else + { + lockStatistics.WriteLockCount++; + } + + checkLockAction(); + } + + await (Task.Delay(10, cancellationToken)); + + lock (readWriteLock) + { + releaser.Dispose(); + if (isRead) + { + lockStatistics.ReadLockCount--; + } + else + { + lockStatistics.WriteLockCount--; + } + + checkLockAction(); + } + } + } + + private static async Task AssertEqualValueAsync(Func getValueFunc, T value, Task task = null, int waitDelay = 5000, CancellationToken cancellationToken = default(CancellationToken)) + { + var currentTime = 0; + var step = 5; + while (currentTime < waitDelay) + { + if (task != null) + { + task.Wait(step); + } + else + { + await (Task.Delay(step, cancellationToken)); + } + + currentTime += step; + if (getValueFunc().Equals(value)) + { + return; + } + + step *= 2; + } + + Assert.That(getValueFunc(), Is.EqualTo(value)); + } + + private static Task AssertTaskCompletedAsync(Task task, CancellationToken cancellationToken = default(CancellationToken)) + { + return AssertEqualValueAsync(() => task.IsCompleted, true, task, cancellationToken: cancellationToken); + } + } +} diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 365e1dcbf99..16ef3aab72e 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -44,6 +44,11 @@ + + + UtilityTest\AsyncReaderWriterLock.cs + + diff --git a/src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs b/src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs new file mode 100644 index 00000000000..b737b044def --- /dev/null +++ b/src/NHibernate.Test/UtilityTest/AsyncReaderWriterLockFixture.cs @@ -0,0 +1,475 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NHibernate.Util; +using NUnit.Framework; + +namespace NHibernate.Test.UtilityTest +{ + public partial class AsyncReaderWriterLockFixture + { + [Test] + public void TestBlocking() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var readReleaser = l.ReadLock(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + + var readReleaserTask = Task.Run(() => l.ReadLock()); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + var writeReleaserTask = Task.Run(() => l.WriteLock()); + AssertEqualValue(() => l.AcquiredWriteLock, true, writeReleaserTask); + AssertEqualValue(() => l.Writing, false, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaser.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertTaskCompleted(writeReleaserTask); + + readReleaserTask = Task.Run(() => l.ReadLock()); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + + var writeReleaserTask2 = Task.Run(() => l.WriteLock()); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask2); + Assert.That(writeReleaserTask2.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask2); + AssertEqualValue(() => l.Writing, true, writeReleaserTask2); + Assert.That(readReleaserTask.IsCompleted, Is.False); + AssertTaskCompleted(writeReleaserTask2); + + writeReleaserTask2.Result.Dispose(); + AssertEqualValue(() => l.Writing, false, writeReleaserTask2); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaserTask.Result.Dispose(); + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.WritersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + Assert.That(l.Writing, Is.False); + } + } + + [Test] + public void TestBlockingAsync() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + var readReleaserTask2 = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask2); + AssertTaskCompleted(readReleaserTask2); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.AcquiredWriteLock, true, writeReleaserTask); + AssertEqualValue(() => l.Writing, false, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask2.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertTaskCompleted(writeReleaserTask); + + readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + + var writeReleaserTask2 = l.WriteLockAsync(); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask2); + Assert.That(writeReleaserTask2.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask2); + AssertEqualValue(() => l.Writing, true, writeReleaserTask2); + Assert.That(readReleaserTask.IsCompleted, Is.False); + AssertTaskCompleted(writeReleaserTask2); + + writeReleaserTask2.Result.Dispose(); + AssertEqualValue(() => l.Writing, false, writeReleaserTask2); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaserTask.Result.Dispose(); + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.WritersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + Assert.That(l.Writing, Is.False); + } + } + + [Test, Explicit] + public void TestConcurrentReadWrite() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var writeReleaser = l.WriteLock(); + Assert.That(l.Writing, Is.True); + + var secondWriteSemaphore = new SemaphoreSlim(0); + var secondWriteReleaser = default(AsyncReaderWriterLock.Releaser); + var secondWriteThread = new Thread( + () => + { + secondWriteSemaphore.Wait(); + secondWriteReleaser = l.WriteLock(); + }); + secondWriteThread.Priority = ThreadPriority.Highest; + secondWriteThread.Start(); + AssertEqualValue(() => secondWriteThread.ThreadState == ThreadState.WaitSleepJoin, true); + + var secondReadThreads = new Thread[20]; + var secondReadReleasers = new AsyncReaderWriterLock.Releaser[secondReadThreads.Length]; + var secondReadSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < secondReadReleasers.Length; j++) + { + var index = j; + var thread = new Thread( + () => + { + secondReadSemaphore.Wait(); + secondReadReleasers[index] = l.ReadLock(); + }); + thread.Priority = ThreadPriority.Highest; + secondReadThreads[j] = thread; + thread.Start(); + } + + AssertEqualValue(() => secondReadThreads.All(o => o.ThreadState == ThreadState.WaitSleepJoin), true); + + var firstReadReleaserTasks = new Task[30]; + var firstReadStopSemaphore = new SemaphoreSlim(0); + for (var j = 0; j < firstReadReleaserTasks.Length; j++) + { + firstReadReleaserTasks[j] = Task.Run(() => + { + var releaser = l.ReadLock(); + firstReadStopSemaphore.Wait(); + releaser.Dispose(); + }); + } + + AssertEqualValue(() => l.ReadersWaiting, firstReadReleaserTasks.Length, waitDelay: 60000); + + writeReleaser.Dispose(); + secondWriteSemaphore.Release(); + secondReadSemaphore.Release(secondReadThreads.Length); + Thread.Sleep(1000); + firstReadStopSemaphore.Release(firstReadReleaserTasks.Length); + + AssertEqualValue(() => firstReadReleaserTasks.All(o => o.IsCompleted), true); + Assert.That(l.ReadersWaiting, Is.EqualTo(secondReadThreads.Length)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + AssertEqualValue(() => secondWriteThread.IsAlive, false); + AssertEqualValue(() => secondReadThreads.All(o => o.IsAlive), true); + + secondWriteReleaser.Dispose(); + AssertEqualValue(() => secondReadThreads.All(o => !o.IsAlive), true); + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(secondReadThreads.Length)); + + foreach (var secondReadReleaser in secondReadReleasers) + { + secondReadReleaser.Dispose(); + } + + Assert.That(l.ReadersWaiting, Is.EqualTo(0)); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + } + } + + [Test] + public void TestInvaildExitReadLockUsage() + { + var l = new AsyncReaderWriterLock(); + var readReleaser = l.ReadLock(); + var readReleaser2 = l.ReadLock(); + + readReleaser.Dispose(); + readReleaser2.Dispose(); + Assert.Throws(() => readReleaser.Dispose()); + Assert.Throws(() => readReleaser2.Dispose()); + } + + [Test] + public void TestOperationAfterDispose() + { + var l = new AsyncReaderWriterLock(); + l.Dispose(); + + Assert.Throws(() => l.ReadLock()); + Assert.Throws(() => l.WriteLock()); + } + + [Test] + public void TestInvaildExitWriteLockUsage() + { + var l = new AsyncReaderWriterLock(); + var writeReleaser = l.WriteLock(); + + writeReleaser.Dispose(); + Assert.Throws(() => writeReleaser.Dispose()); + } + + [Test] + public void TestMixingSyncAndAsync() + { + var l = new AsyncReaderWriterLock(); + var readReleaser = l.ReadLock(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaser.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(1)); + + readReleaserTask.Result.Dispose(); + Assert.That(l.CurrentReaders, Is.EqualTo(0)); + + var writeReleaser = l.WriteLock(); + Assert.That(l.AcquiredWriteLock, Is.True); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + readReleaserTask = Task.Run(() => l.ReadLock()); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + + var readReleaserTask2 = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 2, readReleaserTask2); + Assert.That(readReleaserTask2.IsCompleted, Is.False); + + writeReleaser.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertTaskCompleted(writeReleaserTask); + Assert.That(readReleaserTask.IsCompleted, Is.False); + Assert.That(readReleaserTask2.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask2); + AssertTaskCompleted(readReleaserTask); + AssertTaskCompleted(readReleaserTask2); + } + + [Test] + public void TestWritePriorityOverReadAsync() + { + var l = new AsyncReaderWriterLock(); + for (var i = 0; i < 2; i++) + { + var writeReleaser = l.WriteLock(); + Assert.That(l.AcquiredWriteLock, Is.True); + + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.WritersWaiting, 1, writeReleaserTask); + + writeReleaser.Dispose(); + AssertEqualValue(() => l.WritersWaiting, 0, writeReleaserTask); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask); + AssertTaskCompleted(writeReleaserTask); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + readReleaserTask.Result.Dispose(); + } + } + + [Test] + public void TestPartialReleasingReadLockAsync() + { + var l = new AsyncReaderWriterLock(); + var readReleaserTask = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 1, readReleaserTask); + AssertTaskCompleted(readReleaserTask); + + var readReleaserTask2 = l.ReadLockAsync(); + AssertEqualValue(() => l.CurrentReaders, 2, readReleaserTask); + AssertTaskCompleted(readReleaserTask2); + + var writeReleaserTask = l.WriteLockAsync(); + AssertEqualValue(() => l.AcquiredWriteLock, true, writeReleaserTask); + AssertEqualValue(() => l.Writing, false, writeReleaserTask); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + + var readReleaserTask3 = l.ReadLockAsync(); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask3); + Assert.That(readReleaserTask3.IsCompleted, Is.False); + + readReleaserTask.Result.Dispose(); + Assert.That(writeReleaserTask.IsCompleted, Is.False); + Assert.That(readReleaserTask3.IsCompleted, Is.False); + + readReleaserTask2.Result.Dispose(); + AssertEqualValue(() => l.Writing, true, writeReleaserTask); + AssertEqualValue(() => l.ReadersWaiting, 1, readReleaserTask3); + AssertTaskCompleted(writeReleaserTask); + Assert.That(readReleaserTask3.IsCompleted, Is.False); + + writeReleaserTask.Result.Dispose(); + AssertEqualValue(() => l.ReadersWaiting, 0, readReleaserTask3); + AssertTaskCompleted(readReleaserTask3); + } + + [Test, Explicit] + public async Task TestLoadSyncAndAsync() + { + var l = new AsyncReaderWriterLock(); + var lockStatistics = new LockStatistics(); + var incorrectLockCount = false; + var tasks = new Task[20]; + var masterRandom = new Random(); + var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + + for (var i = 0; i < tasks.Length; i++) + { + // Ensure that each random has its own unique seed + var random = new Random(masterRandom.Next()); + tasks[i] = i % 2 == 0 + ? Task.Run(() => Lock(l, random, lockStatistics, CheckLockCount, CanContinue)) + : LockAsync(l, random, lockStatistics, CheckLockCount, CanContinue); + } + + await Task.WhenAll(tasks); + Assert.That(incorrectLockCount, Is.False); + + void CheckLockCount() + { + if (!lockStatistics.Validate()) + { + Volatile.Write(ref incorrectLockCount, true); + } + } + + bool CanContinue() + { + return !cancellationTokenSource.Token.IsCancellationRequested; + } + } + + private class LockStatistics + { + public int ReadLockCount { get; set; } + + public int WriteLockCount { get; set; } + + public bool Validate() + { + return (ReadLockCount == 0 && WriteLockCount == 0) || + (ReadLockCount > 0 && WriteLockCount == 0) || + (ReadLockCount == 0 && WriteLockCount == 1); + } + } + + private static void Lock( + AsyncReaderWriterLock readWriteLock, + Random random, + LockStatistics lockStatistics, + System.Action checkLockAction, + Func canContinue) + { + while (canContinue()) + { + var isRead = random.Next(100) < 80; + var releaser = isRead ? readWriteLock.ReadLock() : readWriteLock.WriteLock(); + lock (readWriteLock) + { + if (isRead) + { + lockStatistics.ReadLockCount++; + } + else + { + lockStatistics.WriteLockCount++; + } + + checkLockAction(); + } + + Thread.Sleep(10); + + lock (readWriteLock) + { + releaser.Dispose(); + if (isRead) + { + lockStatistics.ReadLockCount--; + } + else + { + lockStatistics.WriteLockCount--; + } + + checkLockAction(); + } + } + } + + private static void AssertEqualValue(Func getValueFunc, T value, Task task = null, int waitDelay = 5000) + { + var currentTime = 0; + var step = 5; + while (currentTime < waitDelay) + { + if (task != null) + { + task.Wait(step); + } + else + { + Thread.Sleep(step); + } + + currentTime += step; + if (getValueFunc().Equals(value)) + { + return; + } + + step *= 2; + } + + Assert.That(getValueFunc(), Is.EqualTo(value)); + } + + private static void AssertTaskCompleted(Task task) + { + AssertEqualValue(() => task.IsCompleted, true, task); + } + } +} diff --git a/src/NHibernate/Async/Cache/ReadWriteCache.cs b/src/NHibernate/Async/Cache/ReadWriteCache.cs index 326e344bbc4..eac3b2bc339 100644 --- a/src/NHibernate/Async/Cache/ReadWriteCache.cs +++ b/src/NHibernate/Async/Cache/ReadWriteCache.cs @@ -12,6 +12,7 @@ using System.Collections.Generic; using System.Linq; using NHibernate.Cache.Access; +using NHibernate.Util; namespace NHibernate.Cache { @@ -19,7 +20,6 @@ namespace NHibernate.Cache using System.Threading; public partial class ReadWriteCache : IBatchableCacheConcurrencyStrategy { - private readonly NHibernate.Util.AsyncLock _lockObjectAsync = new NHibernate.Util.AsyncLock(); /// /// Do not return an item whose timestamp is later than the current @@ -41,7 +41,7 @@ public partial class ReadWriteCache : IBatchableCacheConcurrencyStrategy public async Task GetAsync(CacheKey key, long txTimestamp, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -70,7 +70,8 @@ public async Task GetManyAsync(CacheKey[] keys, long timestamp, Cancel log.Debug("Cache lookup: {0}", string.Join(",", keys.AsEnumerable())); } var result = new object[keys.Length]; - using (await _lockObjectAsync.LockAsync()) + cancellationToken.ThrowIfCancellationRequested(); + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) { var lockables = await (_cache.GetManyAsync(keys, cancellationToken)).ConfigureAwait(false); for (var i = 0; i < lockables.Length; i++) @@ -92,7 +93,7 @@ public async Task GetManyAsync(CacheKey[] keys, long timestamp, Cancel public async Task LockAsync(CacheKey key, object version, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -135,8 +136,9 @@ public async Task PutManyAsync( // MinValue means cache is disabled return result; } + cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -205,8 +207,9 @@ public async Task PutAsync(CacheKey key, object value, long txTimestamp, o // MinValue means cache is disabled return false; } + cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -270,7 +273,7 @@ private Task DecrementLockAsync(object key, CacheLock @lock, CancellationToken c public async Task ReleaseAsync(CacheKey key, ISoftLock clientLock, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -343,7 +346,7 @@ public Task RemoveAsync(CacheKey key, CancellationToken cancellationToken) public async Task AfterUpdateAsync(CacheKey key, object value, object version, ISoftLock clientLock, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { @@ -390,7 +393,7 @@ public async Task AfterUpdateAsync(CacheKey key, object value, object vers public async Task AfterInsertAsync(CacheKey key, object value, object version, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _lockObjectAsync.LockAsync()) + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { if (log.IsDebugEnabled()) { diff --git a/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs b/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs index f97f25401be..774c25bec75 100644 --- a/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs +++ b/src/NHibernate/Async/Cache/UpdateTimestampsCache.cs @@ -22,10 +22,6 @@ namespace NHibernate.Cache using System.Threading; public partial class UpdateTimestampsCache { - private readonly NHibernate.Util.AsyncLock _preInvalidate = new NHibernate.Util.AsyncLock(); - private readonly NHibernate.Util.AsyncLock _invalidate = new NHibernate.Util.AsyncLock(); - private readonly NHibernate.Util.AsyncLock _isUpToDate = new NHibernate.Util.AsyncLock(); - private readonly NHibernate.Util.AsyncLock _areUpToDate = new NHibernate.Util.AsyncLock(); public virtual Task ClearAsync(CancellationToken cancellationToken) { @@ -55,20 +51,20 @@ public Task PreInvalidateAsync(object[] spaces, CancellationToken cancellationTo } } - [MethodImpl()] public virtual async Task PreInvalidateAsync(IReadOnlyCollection spaces, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _preInvalidate.LockAsync()) + if (spaces.Count == 0) + return; + cancellationToken.ThrowIfCancellationRequested(); + + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { //TODO: to handle concurrent writes correctly, this should return a Lock to the client var ts = _updateTimestamps.NextTimestamp() + _updateTimestamps.Timeout; await (SetSpacesTimestampAsync(spaces, ts, cancellationToken)).ConfigureAwait(false); - //TODO: return new Lock(ts); } - - //TODO: return new Lock(ts); } //Since v5.1 @@ -90,11 +86,14 @@ public Task InvalidateAsync(object[] spaces, CancellationToken cancellationToken } } - [MethodImpl()] public virtual async Task InvalidateAsync(IReadOnlyCollection spaces, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _invalidate.LockAsync()) + if (spaces.Count == 0) + return; + cancellationToken.ThrowIfCancellationRequested(); + + using (await (_asyncReaderWriterLock.WriteLockAsync()).ConfigureAwait(false)) { //TODO: to handle concurrent writes correctly, the client should pass in a Lock long ts = _updateTimestamps.NextTimestamp(); @@ -113,9 +112,6 @@ private Task SetSpacesTimestampAsync(IReadOnlyCollection spaces, long ts } try { - if (spaces.Count == 0) - return Task.CompletedTask; - return _updateTimestamps.PutManyAsync( spaces.ToArray(), ArrayHelper.Fill(ts, spaces.Count), cancellationToken); @@ -126,45 +122,45 @@ private Task SetSpacesTimestampAsync(IReadOnlyCollection spaces, long ts } } - [MethodImpl()] public virtual async Task IsUpToDateAsync(ISet spaces, long timestamp /* H2.1 has Long here */, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _isUpToDate.LockAsync()) - { - if (spaces.Count == 0) - return true; + if (spaces.Count == 0) + return true; + cancellationToken.ThrowIfCancellationRequested(); + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) + { var lastUpdates = await (_updateTimestamps.GetManyAsync(spaces.ToArray(), cancellationToken)).ConfigureAwait(false); return lastUpdates.All(lastUpdate => !IsOutdated(lastUpdate as long?, timestamp)); } } - [MethodImpl()] public virtual async Task AreUpToDateAsync(ISet[] spaces, long[] timestamps, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _areUpToDate.LockAsync()) - { - if (spaces.Length == 0) - return Array.Empty(); + if (spaces.Length == 0) + return Array.Empty(); - var allSpaces = new HashSet(); - foreach (var sp in spaces) - { - allSpaces.UnionWith(sp); - } + var allSpaces = new HashSet(); + foreach (var sp in spaces) + { + allSpaces.UnionWith(sp); + } - if (allSpaces.Count == 0) - return ArrayHelper.Fill(true, spaces.Length); + if (allSpaces.Count == 0) + return ArrayHelper.Fill(true, spaces.Length); - var keys = allSpaces.ToArray(); + var keys = allSpaces.ToArray(); + cancellationToken.ThrowIfCancellationRequested(); + using (await (_asyncReaderWriterLock.ReadLockAsync()).ConfigureAwait(false)) + { var index = 0; var lastUpdatesBySpace = - (await (_updateTimestamps - .GetManyAsync(keys, cancellationToken)).ConfigureAwait(false)) - .ToDictionary(u => keys[index++], u => u as long?); + (await (_updateTimestamps + .GetManyAsync(keys, cancellationToken)).ConfigureAwait(false)) + .ToDictionary(u => keys[index++], u => u as long?); var results = new bool[spaces.Length]; for (var i = 0; i < spaces.Length; i++) diff --git a/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs b/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs index 63f63c9ff6c..5ba08f4a31e 100644 --- a/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs +++ b/src/NHibernate/Async/Id/Enhanced/OptimizerFactory.cs @@ -24,13 +24,11 @@ public partial class OptimizerFactory public partial class HiLoOptimizer : OptimizerSupport { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public override async Task GenerateAsync(IAccessCallback callback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_lastSourceValue < 0) { @@ -51,6 +49,7 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance _lastSourceValue = await (callback.GetNextValueAsync(cancellationToken)).ConfigureAwait(false); _upperLimit = (_lastSourceValue * IncrementSize) + 1; } + return Make(_value++); } } @@ -101,13 +100,11 @@ public abstract partial class OptimizerSupport : IOptimizer public partial class PooledOptimizer : OptimizerSupport, IInitialValueAwareOptimizer { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public override async Task GenerateAsync(IAccessCallback callback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_hiValue < 0) { @@ -134,6 +131,7 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance _hiValue = await (callback.GetNextValueAsync(cancellationToken)).ConfigureAwait(false); _value = _hiValue - IncrementSize; } + return Make(_value++); } } @@ -145,13 +143,11 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance public partial class PooledLoOptimizer : OptimizerSupport { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public override async Task GenerateAsync(IAccessCallback callback, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_lastSourceValue < 0 || _value >= (_lastSourceValue + IncrementSize)) { @@ -161,6 +157,7 @@ public override async Task GenerateAsync(IAccessCallback callback, Cance while (_value < 1) _value++; } + return Make(_value++); } } diff --git a/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs b/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs index 3dd339de624..0ca2de00611 100644 --- a/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs +++ b/src/NHibernate/Async/Id/Enhanced/TableGenerator.cs @@ -26,13 +26,11 @@ namespace NHibernate.Id.Enhanced using System.Threading; public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGenerator, IConfigurable { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); - [MethodImpl()] public virtual async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { return await (Optimizer.GenerateAsync(new TableAccessCallback(session, this), cancellationToken)).ConfigureAwait(false); } diff --git a/src/NHibernate/Async/Id/IncrementGenerator.cs b/src/NHibernate/Async/Id/IncrementGenerator.cs index 0fd915d15cf..4df097d6624 100644 --- a/src/NHibernate/Async/Id/IncrementGenerator.cs +++ b/src/NHibernate/Async/Id/IncrementGenerator.cs @@ -19,6 +19,7 @@ using NHibernate.SqlCommand; using NHibernate.SqlTypes; using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Id { @@ -26,7 +27,6 @@ namespace NHibernate.Id using System.Threading; public partial class IncrementGenerator : IIdentifierGenerator, IConfigurable { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); /// /// @@ -35,16 +35,16 @@ public partial class IncrementGenerator : IIdentifierGenerator, IConfigurable /// /// A cancellation token that can be used to cancel the work /// - [MethodImpl()] public async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (_sql != null) { await (GetNextAsync(session, cancellationToken)).ConfigureAwait(false); } + return IdentifierGeneratorFactory.CreateNumber(_next++, _returnClass); } } diff --git a/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs b/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs index 94ee6d72da5..75992f456ed 100644 --- a/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs +++ b/src/NHibernate/Async/Id/SequenceHiLoGenerator.cs @@ -23,7 +23,6 @@ namespace NHibernate.Id using System.Threading; public partial class SequenceHiLoGenerator : SequenceGenerator { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); #region IIdentifierGenerator Members @@ -35,11 +34,10 @@ public partial class SequenceHiLoGenerator : SequenceGenerator /// The entity for which the id is being generated. /// A cancellation token that can be used to cancel the work /// The new identifier as a , , or . - [MethodImpl()] public override async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (maxLo < 1) { diff --git a/src/NHibernate/Async/Id/TableGenerator.cs b/src/NHibernate/Async/Id/TableGenerator.cs index b2731653a92..3ad468a09be 100644 --- a/src/NHibernate/Async/Id/TableGenerator.cs +++ b/src/NHibernate/Async/Id/TableGenerator.cs @@ -29,7 +29,6 @@ namespace NHibernate.Id using System.Threading; public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGenerator, IConfigurable { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); #region IIdentifierGenerator Members @@ -41,11 +40,10 @@ public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGe /// The entity for which the id is being generated. /// A cancellation token that can be used to cancel the work /// The new identifier as a , , or . - [MethodImpl()] public virtual async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { // This has to be done using a different connection to the containing // transaction becase the new hi value must remain valid even if the diff --git a/src/NHibernate/Async/Id/TableHiLoGenerator.cs b/src/NHibernate/Async/Id/TableHiLoGenerator.cs index 8302dad5e98..663733f6b39 100644 --- a/src/NHibernate/Async/Id/TableHiLoGenerator.cs +++ b/src/NHibernate/Async/Id/TableHiLoGenerator.cs @@ -23,7 +23,6 @@ namespace NHibernate.Id using System.Threading; public partial class TableHiLoGenerator : TableGenerator { - private readonly NHibernate.Util.AsyncLock _generate = new NHibernate.Util.AsyncLock(); #region IIdentifierGenerator Members @@ -34,11 +33,10 @@ public partial class TableHiLoGenerator : TableGenerator /// The entity for which the id is being generated. /// A cancellation token that can be used to cancel the work /// The new identifier as a . - [MethodImpl()] public override async Task GenerateAsync(ISessionImplementor session, object obj, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); - using (await _generate.LockAsync()) + using (await (_asyncLock.LockAsync()).ConfigureAwait(false)) { if (maxLo < 1) { diff --git a/src/NHibernate/Cache/ReadWriteCache.cs b/src/NHibernate/Cache/ReadWriteCache.cs index 2bf891f3068..9bb25e51048 100644 --- a/src/NHibernate/Cache/ReadWriteCache.cs +++ b/src/NHibernate/Cache/ReadWriteCache.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using NHibernate.Cache.Access; +using NHibernate.Util; namespace NHibernate.Cache { @@ -33,9 +34,9 @@ public interface ILockable private static readonly INHibernateLogger log = NHibernateLogger.For(typeof(ReadWriteCache)); - private readonly object _lockObject = new object(); private CacheBase _cache; private int _nextLockId; + private readonly AsyncReaderWriterLock _asyncReaderWriterLock = new AsyncReaderWriterLock(); /// /// Gets the cache region name. @@ -95,7 +96,7 @@ private int NextLockId() /// public object Get(CacheKey key, long txTimestamp) { - lock (_lockObject) + using (_asyncReaderWriterLock.ReadLock()) { if (log.IsDebugEnabled()) { @@ -123,7 +124,7 @@ public object[] GetMany(CacheKey[] keys, long timestamp) log.Debug("Cache lookup: {0}", string.Join(",", keys.AsEnumerable())); } var result = new object[keys.Length]; - lock (_lockObject) + using (_asyncReaderWriterLock.ReadLock()) { var lockables = _cache.GetMany(keys); for (var i = 0; i < lockables.Length; i++) @@ -166,7 +167,7 @@ private static object GetValue(long timestamp, CacheKey key, ILockable lockable) /// public ISoftLock Lock(CacheKey key, object version) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -209,7 +210,7 @@ public bool[] PutMany( return result; } - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -278,7 +279,7 @@ public bool Put(CacheKey key, object value, long txTimestamp, object version, IC return false; } - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -330,7 +331,7 @@ private void DecrementLock(object key, CacheLock @lock) public void Release(CacheKey key, ISoftLock clientLock) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -382,6 +383,7 @@ public void Destroy() // The cache is externally provided and may be shared. Destroying the cache is // not the responsibility of this class. Cache = null; + _asyncReaderWriterLock.Dispose(); } /// @@ -390,7 +392,7 @@ public void Destroy() /// public bool AfterUpdate(CacheKey key, object value, object version, ISoftLock clientLock) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { @@ -436,7 +438,7 @@ public bool AfterUpdate(CacheKey key, object value, object version, ISoftLock cl public bool AfterInsert(CacheKey key, object value, object version) { - lock (_lockObject) + using (_asyncReaderWriterLock.WriteLock()) { if (log.IsDebugEnabled()) { diff --git a/src/NHibernate/Cache/UpdateTimestampsCache.cs b/src/NHibernate/Cache/UpdateTimestampsCache.cs index 40369e4ac97..f6851f5ed44 100644 --- a/src/NHibernate/Cache/UpdateTimestampsCache.cs +++ b/src/NHibernate/Cache/UpdateTimestampsCache.cs @@ -19,6 +19,7 @@ public partial class UpdateTimestampsCache { private static readonly INHibernateLogger log = NHibernateLogger.For(typeof(UpdateTimestampsCache)); private readonly CacheBase _updateTimestamps; + private readonly AsyncReaderWriterLock _asyncReaderWriterLock = new AsyncReaderWriterLock(); public virtual void Clear() { @@ -54,14 +55,18 @@ public void PreInvalidate(object[] spaces) PreInvalidate(spaces.OfType().ToList()); } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual void PreInvalidate(IReadOnlyCollection spaces) { - //TODO: to handle concurrent writes correctly, this should return a Lock to the client - var ts = _updateTimestamps.NextTimestamp() + _updateTimestamps.Timeout; - SetSpacesTimestamp(spaces, ts); + if (spaces.Count == 0) + return; - //TODO: return new Lock(ts); + using (_asyncReaderWriterLock.WriteLock()) + { + //TODO: to handle concurrent writes correctly, this should return a Lock to the client + var ts = _updateTimestamps.NextTimestamp() + _updateTimestamps.Timeout; + SetSpacesTimestamp(spaces, ts); + //TODO: return new Lock(ts); + } } //Since v5.1 @@ -72,38 +77,41 @@ public void Invalidate(object[] spaces) Invalidate(spaces.OfType().ToList()); } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual void Invalidate(IReadOnlyCollection spaces) { - //TODO: to handle concurrent writes correctly, the client should pass in a Lock - long ts = _updateTimestamps.NextTimestamp(); - //TODO: if lock.getTimestamp().equals(ts) - if (log.IsDebugEnabled()) - log.Debug("Invalidating spaces [{0}]", StringHelper.CollectionToString(spaces)); - SetSpacesTimestamp(spaces, ts); + if (spaces.Count == 0) + return; + + using (_asyncReaderWriterLock.WriteLock()) + { + //TODO: to handle concurrent writes correctly, the client should pass in a Lock + long ts = _updateTimestamps.NextTimestamp(); + //TODO: if lock.getTimestamp().equals(ts) + if (log.IsDebugEnabled()) + log.Debug("Invalidating spaces [{0}]", StringHelper.CollectionToString(spaces)); + SetSpacesTimestamp(spaces, ts); + } } private void SetSpacesTimestamp(IReadOnlyCollection spaces, long ts) { - if (spaces.Count == 0) - return; - _updateTimestamps.PutMany( spaces.ToArray(), ArrayHelper.Fill(ts, spaces.Count)); } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual bool IsUpToDate(ISet spaces, long timestamp /* H2.1 has Long here */) { if (spaces.Count == 0) return true; - var lastUpdates = _updateTimestamps.GetMany(spaces.ToArray()); - return lastUpdates.All(lastUpdate => !IsOutdated(lastUpdate as long?, timestamp)); + using (_asyncReaderWriterLock.ReadLock()) + { + var lastUpdates = _updateTimestamps.GetMany(spaces.ToArray()); + return lastUpdates.All(lastUpdate => !IsOutdated(lastUpdate as long?, timestamp)); + } } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual bool[] AreUpToDate(ISet[] spaces, long[] timestamps) { if (spaces.Length == 0) @@ -120,20 +128,23 @@ public virtual bool[] AreUpToDate(ISet[] spaces, long[] timestamps) var keys = allSpaces.ToArray(); - var index = 0; - var lastUpdatesBySpace = - _updateTimestamps - .GetMany(keys) - .ToDictionary(u => keys[index++], u => u as long?); - - var results = new bool[spaces.Length]; - for (var i = 0; i < spaces.Length; i++) + using (_asyncReaderWriterLock.ReadLock()) { - var timestamp = timestamps[i]; - results[i] = spaces[i].All(space => !IsOutdated(lastUpdatesBySpace[space], timestamp)); - } + var index = 0; + var lastUpdatesBySpace = + _updateTimestamps + .GetMany(keys) + .ToDictionary(u => keys[index++], u => u as long?); + + var results = new bool[spaces.Length]; + for (var i = 0; i < spaces.Length; i++) + { + var timestamp = timestamps[i]; + results[i] = spaces[i].All(space => !IsOutdated(lastUpdatesBySpace[space], timestamp)); + } - return results; + return results; + } } // Since v5.3 @@ -142,6 +153,7 @@ public virtual void Destroy() { // The cache is externally provided and may be shared. Destroying the cache is // not the responsibility of this class. + _asyncReaderWriterLock.Dispose(); } private static bool IsOutdated(long? lastUpdate, long timestamp) diff --git a/src/NHibernate/Id/Enhanced/OptimizerFactory.cs b/src/NHibernate/Id/Enhanced/OptimizerFactory.cs index 0a410f7fda2..0adf3695551 100644 --- a/src/NHibernate/Id/Enhanced/OptimizerFactory.cs +++ b/src/NHibernate/Id/Enhanced/OptimizerFactory.cs @@ -101,6 +101,7 @@ public partial class HiLoOptimizer : OptimizerSupport private long _upperLimit; private long _lastSourceValue = -1; private long _value; + private readonly AsyncLock _asyncLock = new AsyncLock(); public HiLoOptimizer(System.Type returnClass, int incrementSize) : base(returnClass, incrementSize) { @@ -140,29 +141,32 @@ public override bool ApplyIncrementSizeToSourceValues get { return false; } } - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(IAccessCallback callback) { - if (_lastSourceValue < 0) + using (_asyncLock.Lock()) { - _lastSourceValue = callback.GetNextValue(); - while (_lastSourceValue <= 0) + if (_lastSourceValue < 0) { _lastSourceValue = callback.GetNextValue(); - } + while (_lastSourceValue <= 0) + { + _lastSourceValue = callback.GetNextValue(); + } - // upperLimit defines the upper end of the bucket values - _upperLimit = (_lastSourceValue * IncrementSize) + 1; + // upperLimit defines the upper end of the bucket values + _upperLimit = (_lastSourceValue * IncrementSize) + 1; - // initialize value to the low end of the bucket - _value = _upperLimit - IncrementSize; - } - else if (_upperLimit <= _value) - { - _lastSourceValue = callback.GetNextValue(); - _upperLimit = (_lastSourceValue * IncrementSize) + 1; + // initialize value to the low end of the bucket + _value = _upperLimit - IncrementSize; + } + else if (_upperLimit <= _value) + { + _lastSourceValue = callback.GetNextValue(); + _upperLimit = (_lastSourceValue * IncrementSize) + 1; + } + + return Make(_value++); } - return Make(_value++); } } @@ -267,6 +271,7 @@ public partial class PooledOptimizer : OptimizerSupport, IInitialValueAwareOptim private long _hiValue = -1; private long _value; private long _initialValue; + private readonly AsyncLock _asyncLock = new AsyncLock(); public PooledOptimizer(System.Type returnClass, int incrementSize) : base(returnClass, incrementSize) { @@ -303,35 +308,38 @@ public void InjectInitialValue(long initialValue) _initialValue = initialValue; } - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(IAccessCallback callback) { - if (_hiValue < 0) + using (_asyncLock.Lock()) { - _value = callback.GetNextValue(); - if (_value < 1) + if (_hiValue < 0) { - // unfortunately not really safe to normalize this - // to 1 as an initial value like we do the others - // because we would not be able to control this if - // we are using a sequence... - Log.Info("pooled optimizer source reported [{0}] as the initial value; use of 1 or greater highly recommended", _value); + _value = callback.GetNextValue(); + if (_value < 1) + { + // unfortunately not really safe to normalize this + // to 1 as an initial value like we do the others + // because we would not be able to control this if + // we are using a sequence... + Log.Info("pooled optimizer source reported [{0}] as the initial value; use of 1 or greater highly recommended", _value); + } + + if ((_initialValue == -1 && _value < IncrementSize) || _value == _initialValue) + _hiValue = callback.GetNextValue(); + else + { + _hiValue = _value; + _value = _hiValue - IncrementSize; + } } - - if ((_initialValue == -1 && _value < IncrementSize) || _value == _initialValue) - _hiValue = callback.GetNextValue(); - else + else if (_value >= _hiValue) { - _hiValue = _value; + _hiValue = callback.GetNextValue(); _value = _hiValue - IncrementSize; } + + return Make(_value++); } - else if (_value >= _hiValue) - { - _hiValue = callback.GetNextValue(); - _value = _hiValue - IncrementSize; - } - return Make(_value++); } } @@ -343,6 +351,7 @@ public partial class PooledLoOptimizer : OptimizerSupport { private long _lastSourceValue = -1; // last value read from db source private long _value; // the current generator value + private readonly AsyncLock _asyncLock = new AsyncLock(); public PooledLoOptimizer(System.Type returnClass, int incrementSize) : base(returnClass, incrementSize) { @@ -356,18 +365,21 @@ public PooledLoOptimizer(System.Type returnClass, int incrementSize) : base(retu } } - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(IAccessCallback callback) { - if (_lastSourceValue < 0 || _value >= (_lastSourceValue + IncrementSize)) + using (_asyncLock.Lock()) { - _lastSourceValue = callback.GetNextValue(); - _value = _lastSourceValue; - // handle cases where initial-value is less than one (hsqldb for instance). - while (_value < 1) - _value++; + if (_lastSourceValue < 0 || _value >= (_lastSourceValue + IncrementSize)) + { + _lastSourceValue = callback.GetNextValue(); + _value = _lastSourceValue; + // handle cases where initial-value is less than one (hsqldb for instance). + while (_value < 1) + _value++; + } + + return Make(_value++); } - return Make(_value++); } public override long LastSourceValue diff --git a/src/NHibernate/Id/Enhanced/TableGenerator.cs b/src/NHibernate/Id/Enhanced/TableGenerator.cs index 881280f5237..60a287db13f 100644 --- a/src/NHibernate/Id/Enhanced/TableGenerator.cs +++ b/src/NHibernate/Id/Enhanced/TableGenerator.cs @@ -181,6 +181,7 @@ public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGe private SqlTypes.SqlType[] insertParameterTypes; private SqlString updateQuery; private SqlTypes.SqlType[] updateParameterTypes; + private readonly AsyncLock _asyncLock = new AsyncLock(); public virtual string GeneratorKey() { @@ -378,10 +379,12 @@ protected void BuildInsertQuery() }; } - [MethodImpl(MethodImplOptions.Synchronized)] public virtual object Generate(ISessionImplementor session, object obj) { - return Optimizer.Generate(new TableAccessCallback(session, this)); + using (_asyncLock.Lock()) + { + return Optimizer.Generate(new TableAccessCallback(session, this)); + } } private partial class TableAccessCallback : IAccessCallback diff --git a/src/NHibernate/Id/IncrementGenerator.cs b/src/NHibernate/Id/IncrementGenerator.cs index 1d4fef66ba7..ba4012f60ec 100644 --- a/src/NHibernate/Id/IncrementGenerator.cs +++ b/src/NHibernate/Id/IncrementGenerator.cs @@ -9,6 +9,7 @@ using NHibernate.SqlCommand; using NHibernate.SqlTypes; using NHibernate.Type; +using NHibernate.Util; namespace NHibernate.Id { @@ -32,6 +33,7 @@ public partial class IncrementGenerator : IIdentifierGenerator, IConfigurable private long _next; private SqlString _sql; private System.Type _returnClass; + private readonly AsyncLock _asyncLock = new AsyncLock(); /// /// @@ -85,14 +87,17 @@ public void Configure(IType type, IDictionary parms, Dialect.Dia /// /// /// - [MethodImpl(MethodImplOptions.Synchronized)] public object Generate(ISessionImplementor session, object obj) { - if (_sql != null) + using (_asyncLock.Lock()) { - GetNext(session); + if (_sql != null) + { + GetNext(session); + } + + return IdentifierGeneratorFactory.CreateNumber(_next++, _returnClass); } - return IdentifierGeneratorFactory.CreateNumber(_next++, _returnClass); } private void GetNext(ISessionImplementor session) diff --git a/src/NHibernate/Id/SequenceHiLoGenerator.cs b/src/NHibernate/Id/SequenceHiLoGenerator.cs index 8e07c079548..2d9fca0b85b 100644 --- a/src/NHibernate/Id/SequenceHiLoGenerator.cs +++ b/src/NHibernate/Id/SequenceHiLoGenerator.cs @@ -46,6 +46,7 @@ public partial class SequenceHiLoGenerator : SequenceGenerator private int lo; private long hi; private System.Type returnClass; + private readonly AsyncLock _asyncLock = new AsyncLock(); #region IConfigurable Members @@ -75,27 +76,29 @@ public override void Configure(IType type, IDictionary parms, Di /// The this id is being generated in. /// The entity for which the id is being generated. /// The new identifier as a , , or . - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(ISessionImplementor session, object obj) { - if (maxLo < 1) + using (_asyncLock.Lock()) { - //keep the behavior consistent even for boundary usages - long val = Convert.ToInt64(base.Generate(session, obj)); - if (val == 0) - val = Convert.ToInt64(base.Generate(session, obj)); - return IdentifierGeneratorFactory.CreateNumber(val, returnClass); - } + if (maxLo < 1) + { + //keep the behavior consistent even for boundary usages + long val = Convert.ToInt64(base.Generate(session, obj)); + if (val == 0) + val = Convert.ToInt64(base.Generate(session, obj)); + return IdentifierGeneratorFactory.CreateNumber(val, returnClass); + } - if (lo > maxLo) - { - long hival = Convert.ToInt64(base.Generate(session, obj)); - lo = (hival == 0) ? 1 : 0; - hi = hival * (maxLo + 1); - if (log.IsDebugEnabled()) - log.Debug("new hi value: {0}", hival); + if (lo > maxLo) + { + long hival = Convert.ToInt64(base.Generate(session, obj)); + lo = (hival == 0) ? 1 : 0; + hi = hival * (maxLo + 1); + if (log.IsDebugEnabled()) + log.Debug("new hi value: {0}", hival); + } + return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); } - return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); } #endregion diff --git a/src/NHibernate/Id/TableGenerator.cs b/src/NHibernate/Id/TableGenerator.cs index ce79bf2a532..09180687efc 100644 --- a/src/NHibernate/Id/TableGenerator.cs +++ b/src/NHibernate/Id/TableGenerator.cs @@ -70,6 +70,7 @@ public partial class TableGenerator : TransactionHelper, IPersistentIdentifierGe private SqlString updateSql; private SqlType[] parameterTypes; + private readonly AsyncLock _asyncLock = new AsyncLock(); #region IConfigurable Members @@ -151,13 +152,15 @@ public virtual void Configure(IType type, IDictionary parms, Dia /// The this id is being generated in. /// The entity for which the id is being generated. /// The new identifier as a , , or . - [MethodImpl(MethodImplOptions.Synchronized)] public virtual object Generate(ISessionImplementor session, object obj) { - // This has to be done using a different connection to the containing - // transaction becase the new hi value must remain valid even if the - // containing transaction rolls back. - return DoWorkInNewTransaction(session); + using (_asyncLock.Lock()) + { + // This has to be done using a different connection to the containing + // transaction becase the new hi value must remain valid even if the + // containing transaction rolls back. + return DoWorkInNewTransaction(session); + } } #endregion diff --git a/src/NHibernate/Id/TableHiLoGenerator.cs b/src/NHibernate/Id/TableHiLoGenerator.cs index 402e5dbd467..c64de1d3323 100644 --- a/src/NHibernate/Id/TableHiLoGenerator.cs +++ b/src/NHibernate/Id/TableHiLoGenerator.cs @@ -52,6 +52,7 @@ public partial class TableHiLoGenerator : TableGenerator private long lo; private long maxLo; private System.Type returnClass; + private readonly AsyncLock _asyncLock = new AsyncLock(); #region IConfigurable Members @@ -80,26 +81,28 @@ public override void Configure(IType type, IDictionary parms, Di /// The this id is being generated in. /// The entity for which the id is being generated. /// The new identifier as a . - [MethodImpl(MethodImplOptions.Synchronized)] public override object Generate(ISessionImplementor session, object obj) { - if (maxLo < 1) + using (_asyncLock.Lock()) { - //keep the behavior consistent even for boundary usages - long val = Convert.ToInt64(base.Generate(session, obj)); - if (val == 0) - val = Convert.ToInt64(base.Generate(session, obj)); - return IdentifierGeneratorFactory.CreateNumber(val, returnClass); - } - if (lo > maxLo) - { - long hival = Convert.ToInt64(base.Generate(session, obj)); - lo = (hival == 0) ? 1 : 0; - hi = hival * (maxLo + 1); - log.Debug("New high value: {0}", hival); - } + if (maxLo < 1) + { + //keep the behavior consistent even for boundary usages + long val = Convert.ToInt64(base.Generate(session, obj)); + if (val == 0) + val = Convert.ToInt64(base.Generate(session, obj)); + return IdentifierGeneratorFactory.CreateNumber(val, returnClass); + } + if (lo > maxLo) + { + long hival = Convert.ToInt64(base.Generate(session, obj)); + lo = (hival == 0) ? 1 : 0; + hi = hival * (maxLo + 1); + log.Debug("New high value: {0}", hival); + } - return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); + return IdentifierGeneratorFactory.CreateNumber(hi + lo++, returnClass); + } } #endregion diff --git a/src/NHibernate/Util/AsyncLock.cs b/src/NHibernate/Util/AsyncLock.cs index f322d48f175..8a6f00bc95f 100644 --- a/src/NHibernate/Util/AsyncLock.cs +++ b/src/NHibernate/Util/AsyncLock.cs @@ -8,24 +8,32 @@ namespace NHibernate.Util public sealed class AsyncLock { private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); - private readonly Task _releaser; + private readonly IDisposable _releaser; + private readonly Task _releaserTask; public AsyncLock() { - _releaser = Task.FromResult((IDisposable)new Releaser(this)); + _releaser = new Releaser(this); + _releaserTask = Task.FromResult(_releaser); } public Task LockAsync() { var wait = _semaphore.WaitAsync(); return wait.IsCompleted ? - _releaser : + _releaserTask : wait.ContinueWith( (_, state) => (IDisposable)state, - _releaser.Result, CancellationToken.None, + _releaser, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); } + public IDisposable Lock() + { + _semaphore.Wait(); + return _releaser; + } + private sealed class Releaser : IDisposable { private readonly AsyncLock _toRelease; @@ -33,4 +41,4 @@ private sealed class Releaser : IDisposable public void Dispose() { _toRelease._semaphore.Release(); } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/AsyncReaderWriterLock.cs b/src/NHibernate/Util/AsyncReaderWriterLock.cs new file mode 100644 index 00000000000..46d25b99c9f --- /dev/null +++ b/src/NHibernate/Util/AsyncReaderWriterLock.cs @@ -0,0 +1,252 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace NHibernate.Util +{ + // Idea from: + // https://github.com/kpreisser/AsyncReaderWriterLockSlim + // https://devblogs.microsoft.com/pfxteam/building-async-coordination-primitives-part-7-asyncreaderwriterlock/ + internal class AsyncReaderWriterLock : IDisposable + { + private readonly SemaphoreSlim _writeLockSemaphore = new SemaphoreSlim(1, 1); + private readonly SemaphoreSlim _readLockSemaphore = new SemaphoreSlim(0, 1); + private readonly Releaser _writerReleaser; + private readonly Releaser _readerReleaser; + private readonly Task _readerReleaserTask; + private SemaphoreSlim _waitingReadLockSemaphore; + private SemaphoreSlim _waitingDisposalSemaphore; + private int _readersWaiting; + private int _currentReaders; + private int _writersWaiting; + private bool _disposed; + + public AsyncReaderWriterLock() + { + _writerReleaser = new Releaser(this, true); + _readerReleaser = new Releaser(this, false); + _readerReleaserTask = Task.FromResult(_readerReleaser); + } + + internal int CurrentReaders => _currentReaders; + + internal int WritersWaiting => _writersWaiting; + + internal int ReadersWaiting => _readersWaiting; + + internal bool Writing => _currentReaders == 0 && _writeLockSemaphore.CurrentCount == 0; + + internal bool AcquiredWriteLock => _writeLockSemaphore.CurrentCount == 0; + + public Releaser WriteLock() + { + if (!CanEnterWriteLock(out var waitForReadLocks)) + { + _writeLockSemaphore.Wait(); + lock (_writeLockSemaphore) + { + _writersWaiting--; + } + } + + if (waitForReadLocks) + { + _readLockSemaphore.Wait(); + } + + DisposeWaitingSemaphore(); + + return _writerReleaser; + } + + public async Task WriteLockAsync() + { + if (!CanEnterWriteLock(out var waitForReadLocks)) + { + await _writeLockSemaphore.WaitAsync().ConfigureAwait(false); + lock (_writeLockSemaphore) + { + _writersWaiting--; + } + } + + if (waitForReadLocks) + { + await _readLockSemaphore.WaitAsync().ConfigureAwait(false); + } + + DisposeWaitingSemaphore(); + + return _writerReleaser; + } + + public Releaser ReadLock() + { + if (CanEnterReadLock()) + { + return _readerReleaser; + } + + _waitingReadLockSemaphore.Wait(); + + return _readerReleaser; + } + + public Task ReadLockAsync() + { + return CanEnterReadLock() ? _readerReleaserTask : ReadLockInternalAsync(); + + async Task ReadLockInternalAsync() + { + await _waitingReadLockSemaphore.WaitAsync().ConfigureAwait(false); + + return _readerReleaser; + } + } + + public void Dispose() + { + lock (_writeLockSemaphore) + { + _writeLockSemaphore.Dispose(); + _readLockSemaphore.Dispose(); + _waitingReadLockSemaphore?.Dispose(); + _waitingDisposalSemaphore?.Dispose(); + _disposed = true; + } + } + + private bool CanEnterWriteLock(out bool waitForReadLocks) + { + waitForReadLocks = false; + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_writeLockSemaphore.CurrentCount > 0 && _writeLockSemaphore.Wait(0)) + { + waitForReadLocks = _currentReaders > 0; + return true; + } + + _writersWaiting++; + } + + return false; + } + + private void ExitWriteLock() + { + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_writeLockSemaphore.CurrentCount == 1) + { + throw new InvalidOperationException(); + } + + // Writers have the highest priority even if they came last + if (_writersWaiting > 0 || _waitingReadLockSemaphore == null) + { + _writeLockSemaphore.Release(); + return; + } + + if (_readersWaiting > 0) + { + _currentReaders += _readersWaiting; + _waitingReadLockSemaphore.Release(_readersWaiting); + _readersWaiting = 0; + // We have to dispose the waiting read lock only after all readers finished using it + _waitingDisposalSemaphore = _waitingReadLockSemaphore; + _waitingReadLockSemaphore = null; + } + + _writeLockSemaphore.Release(); + } + } + + private bool CanEnterReadLock() + { + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_writersWaiting == 0 && _writeLockSemaphore.CurrentCount > 0) + { + _currentReaders++; + + return true; + } + + if (_waitingReadLockSemaphore == null) + { + _waitingReadLockSemaphore = new SemaphoreSlim(0); + } + + _readersWaiting++; + + return false; + } + } + + private void ExitReadLock() + { + lock (_writeLockSemaphore) + { + AssertNotDisposed(); + if (_currentReaders == 0) + { + throw new InvalidOperationException(); + } + + _currentReaders--; + if (_currentReaders == 0 && _writeLockSemaphore.CurrentCount == 0) + { + _readLockSemaphore.Release(); + } + } + } + + private void DisposeWaitingSemaphore() + { + _waitingDisposalSemaphore?.Dispose(); + _waitingDisposalSemaphore = null; + } + + private void AssertNotDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(AsyncReaderWriterLock)); + } + } + + public struct Releaser : IDisposable + { + private readonly AsyncReaderWriterLock _toRelease; + private readonly bool _writer; + + internal Releaser(AsyncReaderWriterLock toRelease, bool writer) + { + _toRelease = toRelease; + _writer = writer; + } + + public void Dispose() + { + if (_toRelease == null) + { + return; + } + + if (_writer) + { + _toRelease.ExitWriteLock(); + } + else + { + _toRelease.ExitReadLock(); + } + } + } + } +}