@@ -46,15 +46,15 @@ public sealed partial class SqlCommand : DbCommand, ICloneable
4646 private static readonly Func < SqlCommand , CommandBehavior , AsyncCallback , object , int , bool , bool , IAsyncResult > s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback ;
4747 private static readonly Func < SqlCommand , CommandBehavior , AsyncCallback , object , int , bool , bool , IAsyncResult > s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback ;
4848
49- internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext < SqlCommand , SqlDataReader >
49+ internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext < SqlCommand , SqlDataReader , CancellationTokenRegistration >
5050 {
5151 public Guid OperationID ;
5252 public CommandBehavior CommandBehavior ;
5353
5454 public SqlCommand Command => _owner ;
5555 public TaskCompletionSource < SqlDataReader > TaskCompletionSource => _source ;
5656
57- public void Set ( SqlCommand command , TaskCompletionSource < SqlDataReader > source , IDisposable disposable , CommandBehavior behavior , Guid operationID )
57+ public void Set ( SqlCommand command , TaskCompletionSource < SqlDataReader > source , CancellationTokenRegistration disposable , CommandBehavior behavior , Guid operationID )
5858 {
5959 base . Set ( command , source , disposable ) ;
6060 CommandBehavior = behavior ;
@@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner)
7373 }
7474 }
7575
76+ internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext < SqlCommand , int , CancellationTokenRegistration >
77+ {
78+ public Guid OperationID ;
79+
80+ public SqlCommand Command => _owner ;
81+
82+ public TaskCompletionSource < int > TaskCompletionSource => _source ;
83+
84+ public void Set ( SqlCommand command , TaskCompletionSource < int > source , CancellationTokenRegistration disposable , Guid operationID )
85+ {
86+ base . Set ( command , source , disposable ) ;
87+ OperationID = operationID ;
88+ }
89+
90+ protected override void Clear ( )
91+ {
92+ OperationID = default ;
93+ }
94+
95+ protected override void AfterCleared ( SqlCommand owner )
96+ {
97+
98+ }
99+ }
100+
76101 private CommandType _commandType ;
77102 private int ? _commandTimeout ;
78103 private UpdateRowSource _updatedRowSource = UpdateRowSource . Both ;
@@ -2540,37 +2565,56 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
25402565 }
25412566
25422567 Task < int > returnedTask = source . Task ;
2568+ returnedTask = RegisterForConnectionCloseNotification ( returnedTask ) ;
2569+
2570+ ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext ( ) ;
2571+ context . Set ( this , source , registration , operationId ) ;
25432572 try
25442573 {
2545- returnedTask = RegisterForConnectionCloseNotification ( returnedTask ) ;
2546-
2547- Task < int > . Factory . FromAsync ( BeginExecuteNonQueryAsync , EndExecuteNonQueryAsync , null ) . ContinueWith ( ( t ) =>
2548- {
2549- registration . Dispose ( ) ;
2550- if ( t . IsFaulted )
2551- {
2552- Exception e = t . Exception . InnerException ;
2553- _diagnosticListener . WriteCommandError ( operationId , this , _transaction , e ) ;
2554- source . SetException ( e ) ;
2555- }
2556- else
2574+ Task < int > . Factory . FromAsync (
2575+ static ( AsyncCallback callback , object stateObject ) => ( ( ExecuteNonQueryAsyncCallContext ) stateObject ) . Command . BeginExecuteNonQueryAsync ( callback , stateObject ) ,
2576+ static ( IAsyncResult result ) => ( ( ExecuteNonQueryAsyncCallContext ) result . AsyncState ) . Command . EndExecuteNonQueryAsync ( result ) ,
2577+ state : context
2578+ ) . ContinueWith (
2579+ static ( Task < int > task , object state ) =>
25572580 {
2558- if ( t . IsCanceled )
2581+ ExecuteNonQueryAsyncCallContext context = ( ExecuteNonQueryAsyncCallContext ) state ;
2582+
2583+ Guid operationId = context . OperationID ;
2584+ SqlCommand command = context . Command ;
2585+ TaskCompletionSource < int > source = context . TaskCompletionSource ;
2586+
2587+ context . Dispose ( ) ;
2588+ context = null ;
2589+
2590+ if ( task . IsFaulted )
25592591 {
2560- source . SetCanceled ( ) ;
2592+ Exception e = task . Exception . InnerException ;
2593+ _diagnosticListener . WriteCommandError ( operationId , command , command . _transaction , e ) ;
2594+ source . SetException ( e ) ;
25612595 }
25622596 else
25632597 {
2564- source . SetResult ( t . Result ) ;
2598+ if ( task . IsCanceled )
2599+ {
2600+ source . SetCanceled ( ) ;
2601+ }
2602+ else
2603+ {
2604+ source . SetResult ( task . Result ) ;
2605+ }
2606+ _diagnosticListener . WriteCommandAfter ( operationId , command , command . _transaction ) ;
25652607 }
2566- _diagnosticListener . WriteCommandAfter ( operationId , this , _transaction ) ;
2567- }
2568- } , TaskScheduler . Default ) ;
2608+ } ,
2609+ state : context ,
2610+ scheduler : TaskScheduler . Default
2611+ ) ;
25692612 }
25702613 catch ( Exception e )
25712614 {
25722615 _diagnosticListener . WriteCommandError ( operationId , this , _transaction , e ) ;
25732616 source . SetException ( e ) ;
2617+ context . Dispose ( ) ;
25742618 }
25752619
25762620 return returnedTask ;
@@ -2645,11 +2689,11 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
26452689 }
26462690
26472691 Task < SqlDataReader > returnedTask = source . Task ;
2692+ ExecuteReaderAsyncCallContext context = null ;
26482693 try
26492694 {
26502695 returnedTask = RegisterForConnectionCloseNotification ( returnedTask ) ;
26512696
2652- ExecuteReaderAsyncCallContext context = null ;
26532697 if ( _activeConnection ? . InnerConnection is SqlInternalConnection sqlInternalConnection )
26542698 {
26552699 context = Interlocked . Exchange ( ref sqlInternalConnection . CachedCommandExecuteReaderAsyncContext , null ) ;
@@ -2677,6 +2721,7 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
26772721 }
26782722
26792723 source . SetException ( e ) ;
2724+ context . Dispose ( ) ;
26802725 }
26812726
26822727 return returnedTask ;
0 commit comments