@@ -46,15 +46,15 @@ public sealed partial class SqlCommand : DbCommand, ICloneable
4646 private static readonly Func < SqlCommand , CommandBehavior , AsyncCallback , object , int , bool , bool , IAsyncResult > s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback ;
4747 private static readonly Func < SqlCommand , CommandBehavior , AsyncCallback , object , int , bool , bool , IAsyncResult > s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback ;
4848
49- internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext < SqlCommand , SqlDataReader >
49+ internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext < SqlCommand , SqlDataReader , CancellationTokenRegistration >
5050 {
5151 public Guid OperationID ;
5252 public CommandBehavior CommandBehavior ;
5353
5454 public SqlCommand Command => _owner ;
5555 public TaskCompletionSource < SqlDataReader > TaskCompletionSource => _source ;
5656
57- public void Set ( SqlCommand command , TaskCompletionSource < SqlDataReader > source , IDisposable disposable , CommandBehavior behavior , Guid operationID )
57+ public void Set ( SqlCommand command , TaskCompletionSource < SqlDataReader > source , CancellationTokenRegistration disposable , CommandBehavior behavior , Guid operationID )
5858 {
5959 base . Set ( command , source , disposable ) ;
6060 CommandBehavior = behavior ;
@@ -73,6 +73,31 @@ protected override void AfterCleared(SqlCommand owner)
7373 }
7474 }
7575
76+ internal sealed class ExecuteNonQueryAsyncCallContext : AAsyncCallContext < SqlCommand , int , CancellationTokenRegistration >
77+ {
78+ public Guid OperationID ;
79+
80+ public SqlCommand Command => _owner ;
81+
82+ public TaskCompletionSource < int > TaskCompletionSource => _source ;
83+
84+ public void Set ( SqlCommand command , TaskCompletionSource < int > source , CancellationTokenRegistration disposable , Guid operationID )
85+ {
86+ base . Set ( command , source , disposable ) ;
87+ OperationID = operationID ;
88+ }
89+
90+ protected override void Clear ( )
91+ {
92+ OperationID = default ;
93+ }
94+
95+ protected override void AfterCleared ( SqlCommand owner )
96+ {
97+
98+ }
99+ }
100+
76101 private CommandType _commandType ;
77102 private int ? _commandTimeout ;
78103 private UpdateRowSource _updatedRowSource = UpdateRowSource . Both ;
@@ -2540,23 +2565,36 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
25402565 }
25412566
25422567 Task < int > returnedTask = source . Task ;
2568+ returnedTask = RegisterForConnectionCloseNotification ( returnedTask ) ;
2569+
2570+ ExecuteNonQueryAsyncCallContext context = new ExecuteNonQueryAsyncCallContext ( ) ;
2571+ context . Set ( this , source , registration , operationId ) ;
25432572 try
25442573 {
2545- returnedTask = RegisterForConnectionCloseNotification ( returnedTask ) ;
2546-
2547- Task < int > . Factory . FromAsync ( BeginExecuteNonQueryAsync , EndExecuteNonQueryAsync , null )
2548- . ContinueWith ( ( Task < int > task ) =>
2574+ Task < int > . Factory . FromAsync (
2575+ static ( AsyncCallback callback , object stateObject ) => ( ( ExecuteNonQueryAsyncCallContext ) stateObject ) . Command . BeginExecuteNonQueryAsync ( callback , stateObject ) ,
2576+ static ( IAsyncResult result ) => ( ( ExecuteNonQueryAsyncCallContext ) result . AsyncState ) . Command . EndExecuteNonQueryAsync ( result ) ,
2577+ state : context
2578+ ) . ContinueWith (
2579+ static ( Task < int > task , object state ) =>
25492580 {
2550- registration . Dispose ( ) ;
2581+ ExecuteNonQueryAsyncCallContext context = ( ExecuteNonQueryAsyncCallContext ) state ;
2582+
2583+ Guid operationId = context . OperationID ;
2584+ SqlCommand command = context . Command ;
2585+ TaskCompletionSource < int > source = context . TaskCompletionSource ;
2586+
2587+ context . Dispose ( ) ;
2588+ context = null ;
2589+
25512590 if ( task . IsFaulted )
25522591 {
25532592 Exception e = task . Exception . InnerException ;
2554- s_diagnosticListener . WriteCommandError ( operationId , this , _transaction , e ) ;
2593+ s_diagnosticListener . WriteCommandError ( operationId , command , command . _transaction , e ) ;
25552594 source . SetException ( e ) ;
25562595 }
25572596 else
25582597 {
2559- s_diagnosticListener . WriteCommandAfter ( operationId , this , _transaction ) ;
25602598 if ( task . IsCanceled )
25612599 {
25622600 source . SetCanceled ( ) ;
@@ -2565,15 +2603,18 @@ private Task<int> InternalExecuteNonQueryAsync(CancellationToken cancellationTok
25652603 {
25662604 source . SetResult ( task . Result ) ;
25672605 }
2606+ s_diagnosticListener . WriteCommandAfter ( operationId , command , command . _transaction ) ;
25682607 }
2569- } ,
2570- TaskScheduler . Default
2608+ } ,
2609+ state : context ,
2610+ scheduler : TaskScheduler . Default
25712611 ) ;
25722612 }
25732613 catch ( Exception e )
25742614 {
25752615 s_diagnosticListener . WriteCommandError ( operationId , this , _transaction , e ) ;
25762616 source . SetException ( e ) ;
2617+ context . Dispose ( ) ;
25772618 }
25782619
25792620 return returnedTask ;
@@ -2648,11 +2689,11 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
26482689 }
26492690
26502691 Task < SqlDataReader > returnedTask = source . Task ;
2692+ ExecuteReaderAsyncCallContext context = null ;
26512693 try
26522694 {
26532695 returnedTask = RegisterForConnectionCloseNotification ( returnedTask ) ;
26542696
2655- ExecuteReaderAsyncCallContext context = null ;
26562697 if ( _activeConnection ? . InnerConnection is SqlInternalConnection sqlInternalConnection )
26572698 {
26582699 context = Interlocked . Exchange ( ref sqlInternalConnection . CachedCommandExecuteReaderAsyncContext , null ) ;
@@ -2680,6 +2721,7 @@ private Task<SqlDataReader> InternalExecuteReaderAsync(CommandBehavior behavior,
26802721 }
26812722
26822723 source . SetException ( e ) ;
2724+ context . Dispose ( ) ;
26832725 }
26842726
26852727 return returnedTask ;
0 commit comments