Skip to content

Commit 829b9cf

Browse files
committed
convert ExecuteNonQueryAsync to use async context object
1 parent 9b1996a commit 829b9cf

File tree

3 files changed

+167
-74
lines changed

3 files changed

+167
-74
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,69 @@ namespace Microsoft.Data.SqlClient
1717
// CONSIDER creating your own Set method that calls the base Set rather than providing a parameterized ctor, it is friendlier to caching
1818
// DO NOT use this class' state after Dispose has been called. It will not throw ObjectDisposedException but it will be a cleared object
1919

20-
internal abstract class AAsyncCallContext<TOwner, TTask> : IDisposable
20+
internal abstract class AAsyncCallContext<TOwner, TTask, TDisposable> : AAsyncBaseCallContext<TOwner,TTask>
2121
where TOwner : class
22+
where TDisposable : IDisposable
2223
{
23-
protected TOwner _owner;
24-
protected TaskCompletionSource<TTask> _source;
25-
protected IDisposable _disposable;
24+
protected TDisposable _disposable;
2625

2726
protected AAsyncCallContext()
2827
{
2928
}
3029

31-
protected AAsyncCallContext(TOwner owner, TaskCompletionSource<TTask> source, IDisposable disposable = null)
30+
protected AAsyncCallContext(TOwner owner, TaskCompletionSource<TTask> source, TDisposable disposable = default)
3231
{
3332
Set(owner, source, disposable);
3433
}
3534

36-
protected void Set(TOwner owner, TaskCompletionSource<TTask> source, IDisposable disposable = null)
35+
protected void Set(TOwner owner, TaskCompletionSource<TTask> source, TDisposable disposable = default)
36+
{
37+
base.Set(owner, source);
38+
_disposable = disposable;
39+
}
40+
41+
protected override void DisposeCore()
42+
{
43+
TDisposable copyDisposable = _disposable;
44+
_disposable = default;
45+
_isDisposed = true;
46+
copyDisposable?.Dispose();
47+
}
48+
}
49+
50+
internal abstract class AAsyncBaseCallContext<TOwner, TTask>
51+
{
52+
protected TOwner _owner;
53+
protected TaskCompletionSource<TTask> _source;
54+
protected bool _isDisposed;
55+
56+
protected AAsyncBaseCallContext()
57+
{
58+
}
59+
60+
protected void Set(TOwner owner, TaskCompletionSource<TTask> source)
3761
{
3862
_owner = owner ?? throw new ArgumentNullException(nameof(owner));
3963
_source = source ?? throw new ArgumentNullException(nameof(source));
40-
_disposable = disposable;
64+
_isDisposed = false;
4165
}
4266

4367
protected void ClearCore()
4468
{
4569
_source = null;
4670
_owner = default;
47-
IDisposable copyDisposable = _disposable;
48-
_disposable = null;
49-
copyDisposable?.Dispose();
71+
try
72+
{
73+
DisposeCore();
74+
}
75+
finally
76+
{
77+
_isDisposed = true;
78+
}
5079
}
5180

81+
protected abstract void DisposeCore();
82+
5283
/// <summary>
5384
/// override this method to cleanup instance data before ClearCore is called which will blank the base data
5485
/// </summary>
@@ -65,16 +96,19 @@ protected virtual void AfterCleared(TOwner owner)
6596

6697
public void Dispose()
6798
{
68-
TOwner owner = _owner;
69-
try
70-
{
71-
Clear();
72-
}
73-
finally
99+
if (!_isDisposed)
74100
{
75-
ClearCore();
101+
TOwner owner = _owner;
102+
try
103+
{
104+
Clear();
105+
}
106+
finally
107+
{
108+
ClearCore();
109+
}
110+
AfterCleared(owner);
76111
}
77-
AfterCleared(owner);
78112
}
79113
}
80114
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ public sealed partial class SqlCommand : DbCommand, ICloneable
4646
private static readonly Func<SqlCommand, CommandBehavior, AsyncCallback, object, int, bool, bool, IAsyncResult> s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback;
4747
private static readonly Func<SqlCommand, CommandBehavior, AsyncCallback, object, int, bool, bool, IAsyncResult> s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback;
4848

49-
internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext<SqlCommand, SqlDataReader>
49+
internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext<SqlCommand, SqlDataReader, CancellationTokenRegistration>
5050
{
5151
public Guid OperationID;
5252
public CommandBehavior CommandBehavior;
5353

5454
public SqlCommand Command => _owner;
5555
public TaskCompletionSource<SqlDataReader> TaskCompletionSource => _source;
5656

57-
public void Set(SqlCommand command, TaskCompletionSource<SqlDataReader> source, IDisposable disposable, CommandBehavior behavior, Guid operationID)
57+
public void Set(SqlCommand command, TaskCompletionSource<SqlDataReader> source, CancellationTokenRegistration disposable, CommandBehavior behavior, Guid operationID)
5858
{
5959
base.Set(command, source, disposable);
6060
CommandBehavior = behavior;
@@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner)
7373
}
7474
}
7575

76+
internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext<SqlCommand, int, CancellationTokenRegistration>
77+
{
78+
public Guid OperationID;
79+
80+
public SqlCommand Command => _owner;
81+
82+
public TaskCompletionSource<int> TaskCompletionSource => _source;
83+
84+
public void Set(SqlCommand command, TaskCompletionSource<int> source, CancellationTokenRegistration disposable, Guid operationID)
85+
{
86+
base.Set(command, source, disposable);
87+
OperationID = operationID;
88+
}
89+
90+
protected override void Clear()
91+
{
92+
OperationID = default;
93+
}
94+
95+
protected override void AfterCleared(SqlCommand owner)
96+
{
97+
98+
}
99+
}
100+
76101
private CommandType _commandType;
77102
private int? _commandTimeout;
78103
private UpdateRowSource _updatedRowSource = UpdateRowSource.Both;
@@ -2540,37 +2565,56 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
25402565
}
25412566

25422567
Task<int> returnedTask = source.Task;
2568+
returnedTask = RegisterForConnectionCloseNotification(returnedTask);
2569+
2570+
ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext();
2571+
context.Set(this, source, registration, operationId);
25432572
try
25442573
{
2545-
returnedTask = RegisterForConnectionCloseNotification(returnedTask);
2546-
2547-
Task<int>.Factory.FromAsync(BeginExecuteNonQueryAsync, EndExecuteNonQueryAsync, null).ContinueWith((t) =>
2548-
{
2549-
registration.Dispose();
2550-
if (t.IsFaulted)
2551-
{
2552-
Exception e = t.Exception.InnerException;
2553-
_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
2554-
source.SetException(e);
2555-
}
2556-
else
2574+
Task<int>.Factory.FromAsync(
2575+
static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject),
2576+
static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result),
2577+
state: context
2578+
).ContinueWith(
2579+
static (Task<int> task, object state) =>
25572580
{
2558-
if (t.IsCanceled)
2581+
ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state;
2582+
2583+
Guid operationId = context.OperationID;
2584+
SqlCommand command = context.Command;
2585+
TaskCompletionSource<int> source = context.TaskCompletionSource;
2586+
2587+
context.Dispose();
2588+
context = null;
2589+
2590+
if (task.IsFaulted)
25592591
{
2560-
source.SetCanceled();
2592+
Exception e = task.Exception.InnerException;
2593+
_diagnosticListener.WriteCommandError(operationId, command, command._transaction, e);
2594+
source.SetException(e);
25612595
}
25622596
else
25632597
{
2564-
source.SetResult(t.Result);
2598+
if (task.IsCanceled)
2599+
{
2600+
source.SetCanceled();
2601+
}
2602+
else
2603+
{
2604+
source.SetResult(task.Result);
2605+
}
2606+
_diagnosticListener.WriteCommandAfter(operationId, command, command._transaction);
25652607
}
2566-
_diagnosticListener.WriteCommandAfter(operationId, this, _transaction);
2567-
}
2568-
}, TaskScheduler.Default);
2608+
},
2609+
state: context,
2610+
scheduler: TaskScheduler.Default
2611+
);
25692612
}
25702613
catch (Exception e)
25712614
{
25722615
_diagnosticListener.WriteCommandError(operationId, this, _transaction, e);
25732616
source.SetException(e);
2617+
context.Dispose();
25742618
}
25752619

25762620
return returnedTask;
@@ -2645,11 +2689,11 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
26452689
}
26462690

26472691
Task<SqlDataReader> returnedTask = source.Task;
2692+
ExecuteReaderAsyncCallContext context = null;
26482693
try
26492694
{
26502695
returnedTask = RegisterForConnectionCloseNotification(returnedTask);
26512696

2652-
ExecuteReaderAsyncCallContext context = null;
26532697
if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection)
26542698
{
26552699
context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteReaderAsyncContext, null);
@@ -2677,6 +2721,7 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
26772721
}
26782722

26792723
source.SetException(e);
2724+
context.Dispose();
26802725
}
26812726

26822727
return returnedTask;

0 commit comments

Comments
 (0)