From 404c88bd80eb536789d0f332c4b4d36ee603cb7e Mon Sep 17 00:00:00 2001
From: Wraith2 <wraith2@gmail.com>
Date: Sun, 21 Feb 2021 19:28:49 +0000
Subject: [PATCH] Clean up AAsyncCallContext and SqlDataReader uses of it

---
 .../Microsoft/Data/SqlClient/SqlDataReader.cs | 339 ++++++++++--------
 1 file changed, 181 insertions(+), 158 deletions(-)

diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
index 3fcad61f2e..ef6c6ff941 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs
@@ -4298,16 +4298,16 @@ private static Task<bool> NextResultAsyncExecute(Task task, object state)
             if (task != null)
             {
                 SqlClientEventSource.Log.TryTraceEvent("SqlDataReader.NextResultAsyncExecute | attempt retry {0}", ObjectID);
-                context._reader.PrepareForAsyncContinuation();
+                context.Reader.PrepareForAsyncContinuation();
             }
 
-            if (context._reader.TryNextResult(out bool more))
+            if (context.Reader.TryNextResult(out bool more))
             {
                 // completed
                 return more ? ADP.TrueTask : ADP.FalseTask;
             }
 
-            return context._reader.ExecuteAsyncCall(context);
+            return context.Reader.ExecuteAsyncCall(context);
         }
 
         // NOTE: This will return null if it completed sequentially
@@ -4338,12 +4338,12 @@ internal Task<int> GetBytesAsync(int columnIndex, byte[] buffer, int index, int
 
             var context = new GetBytesAsyncCallContext(this)
             {
-                columnIndex = columnIndex,
-                buffer = buffer,
-                index = index,
-                length = length,
-                timeout = timeout,
-                cancellationToken = cancellationToken,
+                _columnIndex = columnIndex,
+                _buffer = buffer,
+                _index = index,
+                _length = length,
+                _timeout = timeout,
+                _cancellationToken = cancellationToken,
             };
 
             // Check if we need to skip columns
@@ -4368,18 +4368,18 @@ internal Task<int> GetBytesAsync(int columnIndex, byte[] buffer, int index, int
                     timeoutToken = timeoutCancellationSource.Token;
                 }
 
-                context._disposable = timeoutCancellationSource;
-                context.timeoutToken = timeoutToken;
-                context._source = source;
 
                 PrepareAsyncInvocation(useSnapshot: true);
 
+                context.Set(this, source, timeoutCancellationSource);
+                context._timeoutToken = timeoutToken;
+
                 return InvokeAsyncCall(context);
             }
             else
             {
                 // We're already at the correct column, just read the data
-                context.mode = GetBytesAsyncCallContext.OperationMode.Read;
+                context._mode = GetBytesAsyncCallContext.OperationMode.Read;
 
                 // Switch to async
                 PrepareAsyncInvocation(useSnapshot: false);
@@ -4399,9 +4399,9 @@ internal Task<int> GetBytesAsync(int columnIndex, byte[] buffer, int index, int
         private static Task<int> GetBytesAsyncSeekExecute(Task task, object state)
         {
             GetBytesAsyncCallContext context = (GetBytesAsyncCallContext)state;
-            SqlDataReader reader = context._reader;
+            SqlDataReader reader = context.Reader;
 
-            Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Seek, "context.mode must be Seek to check if seeking can resume");
+            Debug.Assert(context._mode == GetBytesAsyncCallContext.OperationMode.Seek, "context.mode must be Seek to check if seeking can resume");
 
             if (task != null)
             {
@@ -4411,16 +4411,16 @@ private static Task<int> GetBytesAsyncSeekExecute(Task task, object state)
             // Prepare for stateObj timeout
             reader.SetTimeout(reader._defaultTimeoutMilliseconds);
 
-            if (reader.TryReadColumnHeader(context.columnIndex))
+            if (reader.TryReadColumnHeader(context._columnIndex))
             {
                 // Only once we have read up to where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state)
 
-                if (context.cancellationToken.IsCancellationRequested)
+                if (context._cancellationToken.IsCancellationRequested)
                 {
                     // User requested cancellation
-                    return Task.FromCanceled<int>(context.cancellationToken);
+                    return Task.FromCanceled<int>(context._cancellationToken);
                 }
-                else if (context.timeoutToken.IsCancellationRequested)
+                else if (context._timeoutToken.IsCancellationRequested)
                 {
                     // Timeout
                     return Task.FromException<int>(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout())));
@@ -4428,7 +4428,7 @@ private static Task<int> GetBytesAsyncSeekExecute(Task task, object state)
                 else
                 {
                     // Up to the correct column - continue to read
-                    context.mode = GetBytesAsyncCallContext.OperationMode.Read;
+                    context._mode = GetBytesAsyncCallContext.OperationMode.Read;
                     reader.SwitchToAsyncWithoutSnapshot();
                     int totalBytesRead;
                     var readTask = reader.GetBytesAsyncReadDataStage(context, true, out totalBytesRead);
@@ -4452,18 +4452,18 @@ private static Task<int> GetBytesAsyncSeekExecute(Task task, object state)
         private static Task<int> GetBytesAsyncReadExecute(Task task, object state)
         {
             var context = (GetBytesAsyncCallContext)state;
-            SqlDataReader reader = context._reader;
+            SqlDataReader reader = context.Reader;
 
-            Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.mode must be Read to check if read can resume");
+            Debug.Assert(context._mode == GetBytesAsyncCallContext.OperationMode.Read, "context.mode must be Read to check if read can resume");
 
             reader.PrepareForAsyncContinuation();
 
-            if (context.cancellationToken.IsCancellationRequested)
+            if (context._cancellationToken.IsCancellationRequested)
             {
                 // User requested cancellation
-                return Task.FromCanceled<int>(context.cancellationToken);
+                return Task.FromCanceled<int>(context._cancellationToken);
             }
-            else if (context.timeoutToken.IsCancellationRequested)
+            else if (context._timeoutToken.IsCancellationRequested)
             {
                 // Timeout
                 return Task.FromException<int>(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout())));
@@ -4475,18 +4475,18 @@ private static Task<int> GetBytesAsyncReadExecute(Task task, object state)
 
                 int bytesReadThisIteration;
                 bool result = reader.TryGetBytesInternalSequential(
-                    context.columnIndex,
-                    context.buffer,
-                    context.index + context.totalBytesRead,
-                    context.length - context.totalBytesRead,
+                    context._columnIndex,
+                    context._buffer,
+                    context._index + context._totalBytesRead,
+                    context._length - context._totalBytesRead,
                     out bytesReadThisIteration
                 );
-                context.totalBytesRead += bytesReadThisIteration;
-                Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required");
+                context._totalBytesRead += bytesReadThisIteration;
+                Debug.Assert(context._totalBytesRead <= context._length, "Read more bytes than required");
 
                 if (result)
                 {
-                    return Task.FromResult<int>(context.totalBytesRead);
+                    return Task.FromResult<int>(context._totalBytesRead);
                 }
                 else
                 {
@@ -4497,24 +4497,24 @@ out bytesReadThisIteration
 
         private Task<int> GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, bool isContinuation, out int bytesRead)
         {
-            Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.Mode must be Read to read data");
+            Debug.Assert(context._mode == GetBytesAsyncCallContext.OperationMode.Read, "context.Mode must be Read to read data");
 
-            _lastColumnWithDataChunkRead = context.columnIndex;
+            _lastColumnWithDataChunkRead = context._columnIndex;
             TaskCompletionSource<int> source = null;
 
             // Prepare for stateObj timeout
             SetTimeout(_defaultTimeoutMilliseconds);
 
             // Try to read without any continuations (all the data may already be in the stateObj's buffer)
-            bool filledBuffer = context._reader.TryGetBytesInternalSequential(
-                context.columnIndex,
-                context.buffer,
-                context.index + context.totalBytesRead,
-                context.length - context.totalBytesRead,
+            bool filledBuffer = context.Reader.TryGetBytesInternalSequential(
+                context._columnIndex,
+                context._buffer,
+                context._index + context._totalBytesRead,
+                context._length - context._totalBytesRead,
                 out bytesRead
             );
-            context.totalBytesRead += bytesRead;
-            Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required");
+            context._totalBytesRead += bytesRead;
+            Debug.Assert(context._totalBytesRead <= context._length, "Read more bytes than required");
 
             if (!filledBuffer)
             {
@@ -4522,7 +4522,7 @@ out bytesRead
                 if (!isContinuation)
                 {
                     // This is the first async operation which is happening - setup the _currentTask and timeout
-                    Debug.Assert(context._source == null, "context._source should not be non-null when trying to change to async");
+                    Debug.Assert(context.Source == null, "context._source should not be non-null when trying to change to async");
                     source = new TaskCompletionSource<int>();
                     Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
                     if (original != null)
@@ -4530,7 +4530,7 @@ out bytesRead
                         source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
                         return source.Task;
                     }
-                    context._source = source;
+                    context.Source = source;
                     // Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
                     if (_cancelAsyncOnCloseToken.IsCancellationRequested)
                     {
@@ -4540,14 +4540,14 @@ out bytesRead
                     }
 
                     // Timeout
-                    Debug.Assert(context.timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation");
-                    if (context.timeout > 0)
+                    Debug.Assert(context._timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation");
+                    if (context._timeout > 0)
                     {
                         CancellationTokenSource timeoutCancellationSource = new CancellationTokenSource();
-                        timeoutCancellationSource.CancelAfter(context.timeout);
-                        Debug.Assert(context._disposable is null, "setting context.disposable would lose the previous disposable");
-                        context._disposable = timeoutCancellationSource;
-                        context.timeoutToken = timeoutCancellationSource.Token;
+                        timeoutCancellationSource.CancelAfter(context._timeout);
+                        Debug.Assert(context.Disposable is null, "setting context.disposable would lose the previous disposable");
+                        context.Disposable = timeoutCancellationSource;
+                        context._timeoutToken = timeoutCancellationSource.Token;
                     }
                 }
 
@@ -4559,10 +4559,10 @@ out bytesRead
                 }
                 else
                 {
-                    Debug.Assert(context._source != null, "context._source should not be null when continuing");
+                    Debug.Assert(context.Source != null, "context._source should not be null when continuing");
                     // setup for cleanup/completing
                     retryTask.ContinueWith(
-                        continuationAction: AAsyncCallContext<int>.s_completeCallback,
+                        continuationAction: SqlDataReaderAsyncCallContext<int>.s_completeCallback,
                         state: context,
                         TaskScheduler.Default
                     );
@@ -4704,7 +4704,7 @@ public override Task<bool> ReadAsync(CancellationToken cancellationToken)
                     context = new ReadAsyncCallContext();
                 }
 
-                Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed");
+                Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ReadAsyncCallContext was not properly disposed");
 
                 context.Set(this, source, registration);
                 context._hasMoreData = more;
@@ -4723,7 +4723,7 @@ public override Task<bool> ReadAsync(CancellationToken cancellationToken)
         private static Task<bool> ReadAsyncExecute(Task task, object state)
         {
             var context = (ReadAsyncCallContext)state;
-            SqlDataReader reader = context._reader;
+            SqlDataReader reader = context.Reader;
             ref bool hasMoreData = ref context._hasMoreData;
             ref bool hasReadRowToken = ref context._hasReadRowToken;
 
@@ -4882,7 +4882,7 @@ override public Task<bool> IsDBNullAsync(int i, CancellationToken cancellationTo
                     context = new IsDBNullAsyncCallContext();
                 }
 
-                Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed");
+                Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == null, "cached ISDBNullAsync context not properly disposed");
 
                 context.Set(this, source, registration);
                 context._columnIndex = i;
@@ -4897,7 +4897,7 @@ override public Task<bool> IsDBNullAsync(int i, CancellationToken cancellationTo
         private static Task<bool> IsDBNullAsyncExecute(Task task, object state)
         {
             IsDBNullAsyncCallContext context = (IsDBNullAsyncCallContext)state;
-            SqlDataReader reader = context._reader;
+            SqlDataReader reader = context.Reader;
 
             if (task != null)
             {
@@ -4976,7 +4976,7 @@ override public Task<T> GetFieldValueAsync<T>(int i, CancellationToken cancellat
                     {
                         _stateObj._shouldHaveEnoughData = true;
 #endif
-                        return Task.FromResult(GetFieldValueInternal<T>(i));
+                    return Task.FromResult(GetFieldValueInternal<T>(i));
 #if DEBUG
                     }
                     finally
@@ -5016,19 +5016,22 @@ override public Task<T> GetFieldValueAsync<T>(int i, CancellationToken cancellat
             IDisposable registration = null;
             if (cancellationToken.CanBeCanceled)
             {
-                registration = cancellationToken.Register(s => ((SqlCommand)s).CancelIgnoreFailure(), _command);
+                registration = cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command);
             }
 
             // Setup async
             PrepareAsyncInvocation(useSnapshot: true);
 
-            return InvokeAsyncCall(new GetFieldValueAsyncCallContext<T>(this, source, registration, i));
+            GetFieldValueAsyncCallContext<T> context = new GetFieldValueAsyncCallContext<T>(this, source, registration);
+            context._columnIndex = i;
+
+            return InvokeAsyncCall(context);
         }
 
         private static Task<T> GetFieldValueAsyncExecute<T>(Task task, object state)
         {
             GetFieldValueAsyncCallContext<T> context = (GetFieldValueAsyncCallContext<T>)state;
-            SqlDataReader reader = context._reader;
+            SqlDataReader reader = context.Reader;
             int columnIndex = context._columnIndex;
             if (task != null)
             {
@@ -5067,71 +5070,48 @@ internal void CompletePendingReadWithFailure(int errorCode, bool resetForcePendi
 
 #endif
 
-        internal class Snapshot
+        internal abstract class SqlDataReaderAsyncCallContext<T> : AAsyncCallContext<SqlDataReader,T>
         {
-            public bool _dataReady;
-            public bool _haltRead;
-            public bool _metaDataConsumed;
-            public bool _browseModeInfoConsumed;
-            public bool _hasRows;
-            public ALTROWSTATUS _altRowStatus;
-            public int _nextColumnDataToRead;
-            public int _nextColumnHeaderToRead;
-            public long _columnDataBytesRead;
-            public long _columnDataBytesRemaining;
+            internal static readonly Action<Task<T>, object> s_completeCallback = CompleteAsyncCallCallback;
 
-            public _SqlMetaDataSet _metadata;
-            public _SqlMetaDataSetCollection _altMetaDataSetCollection;
-            public MultiPartTableName[] _tableNames;
+            internal static readonly Func<Task, object, Task<T>> s_executeCallback = ExecuteAsyncCallCallback;
 
-            public SqlSequentialStream _currentStream;
-            public SqlSequentialTextReader _currentTextReader;
-        }
-
-        internal abstract class AAsyncCallContext<T> : IDisposable
-        {
-            internal static readonly Action<Task<T>, object> s_completeCallback = SqlDataReader.CompleteAsyncCallCallback<T>;
-
-            internal static readonly Func<Task, object, Task<T>> s_executeCallback = SqlDataReader.ExecuteAsyncCallCallback<T>;
-
-            internal SqlDataReader _reader;
-            internal TaskCompletionSource<T> _source;
-            internal IDisposable _disposable;
-
-            protected AAsyncCallContext()
+            protected SqlDataReaderAsyncCallContext()
             {
             }
 
-            protected AAsyncCallContext(SqlDataReader reader, TaskCompletionSource<T> source, IDisposable disposable = null)
+            protected SqlDataReaderAsyncCallContext(SqlDataReader owner, TaskCompletionSource<T> source, IDisposable disposable = null)
             {
-                Set(reader, source, disposable);
+                Set(owner, source, disposable);
             }
 
-            internal void Set(SqlDataReader reader, TaskCompletionSource<T> source, IDisposable disposable = null)
+            internal abstract Func<Task, object, Task<T>> Execute { get; }
+
+            internal SqlDataReader Reader { get => _owner; set => _owner = value; }
+
+            public IDisposable Disposable { get => _disposable; set => _disposable = value; }
+
+            public TaskCompletionSource<T> Source { get => _source; set => _source = value; }
+
+            new public void Set(SqlDataReader reader, TaskCompletionSource<T> source, IDisposable disposable)
             {
-                this._reader = reader ?? throw new ArgumentNullException(nameof(reader));
-                this._source = source ?? throw new ArgumentNullException(nameof(source));
-                this._disposable = disposable;
+                base.Set(reader, source, disposable);
             }
 
-            internal void Clear()
+            private static Task<T> ExecuteAsyncCallCallback(Task task, object state)
             {
-                _source = null;
-                _reader = null;
-                IDisposable copyDisposable = _disposable;
-                _disposable = null;
-                copyDisposable?.Dispose();
+                SqlDataReaderAsyncCallContext<T> context = (SqlDataReaderAsyncCallContext<T>)state;
+                return context.Reader.ContinueAsyncCall(task, context);
             }
 
-            internal abstract Func<Task, object, Task<T>> Execute { get; }
-
-            public virtual void Dispose()
+            private static void CompleteAsyncCallCallback(Task<T> task, object state)
             {
-                Clear();
+                SqlDataReaderAsyncCallContext<T> context = (SqlDataReaderAsyncCallContext<T>)state;
+                context.Reader.CompleteAsyncCall(task, context);
             }
         }
 
-        internal sealed class ReadAsyncCallContext : AAsyncCallContext<bool>
+        internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext<bool>
         {
             internal static readonly Func<Task, object, Task<bool>> s_execute = SqlDataReader.ReadAsyncExecute;
 
@@ -5144,15 +5124,13 @@ internal ReadAsyncCallContext()
 
             internal override Func<Task, object, Task<bool>> Execute => s_execute;
 
-            public override void Dispose()
+            protected override void AfterCleared(SqlDataReader owner)
             {
-                SqlDataReader reader = this._reader;
-                base.Dispose();
-                reader.SetCachedReadAsyncCallContext(this);
+                owner.SetCachedReadAsyncCallContext(this);
             }
         }
 
-        internal sealed class IsDBNullAsyncCallContext : AAsyncCallContext<bool>
+        internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext<bool>
         {
             internal static readonly Func<Task, object, Task<bool>> s_execute = SqlDataReader.IsDBNullAsyncExecute;
 
@@ -5162,15 +5140,13 @@ internal IsDBNullAsyncCallContext() { }
 
             internal override Func<Task, object, Task<bool>> Execute => s_execute;
 
-            public override void Dispose()
+            protected override void AfterCleared(SqlDataReader owner)
             {
-                SqlDataReader reader = this._reader;
-                base.Dispose();
-                reader.SetCachedIDBNullAsyncCallContext(this);
+                owner.SetCachedIDBNullAsyncCallContext(this);
             }
         }
 
-        private sealed class HasNextResultAsyncCallContext : AAsyncCallContext<bool>
+        private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext<bool>
         {
             private static readonly Func<Task, object, Task<bool>> s_execute = SqlDataReader.NextResultAsyncExecute;
 
@@ -5182,7 +5158,7 @@ public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource<
             internal override Func<Task, object, Task<bool>> Execute => s_execute;
         }
 
-        private sealed class GetBytesAsyncCallContext : AAsyncCallContext<int>
+        private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext<int>
         {
             internal enum OperationMode
             {
@@ -5193,63 +5169,66 @@ internal enum OperationMode
             private static readonly Func<Task, object, Task<int>> s_executeSeek = SqlDataReader.GetBytesAsyncSeekExecute;
             private static readonly Func<Task, object, Task<int>> s_executeRead = SqlDataReader.GetBytesAsyncReadExecute;
 
-            internal int columnIndex;
-            internal byte[] buffer;
-            internal int index;
-            internal int length;
-            internal int timeout;
-            internal CancellationToken cancellationToken;
-            internal CancellationToken timeoutToken;
-            internal int totalBytesRead;
+            internal int _columnIndex;
+            internal byte[] _buffer;
+            internal int _index;
+            internal int _length;
+            internal int _timeout;
+            internal CancellationToken _cancellationToken;
+            internal CancellationToken _timeoutToken;
+            internal int _totalBytesRead;
 
-            internal OperationMode mode;
+            internal OperationMode _mode;
 
             internal GetBytesAsyncCallContext(SqlDataReader reader)
             {
-                this._reader = reader ?? throw new ArgumentNullException(nameof(reader));
+                Reader = reader ?? throw new ArgumentNullException(nameof(reader));
             }
 
-            internal override Func<Task, object, Task<int>> Execute => mode == OperationMode.Seek ? s_executeSeek : s_executeRead;
+            internal override Func<Task, object, Task<int>> Execute => _mode == OperationMode.Seek ? s_executeSeek : s_executeRead;
 
-            public override void Dispose()
+            protected override void Clear()
             {
-                buffer = null;
-                cancellationToken = default;
-                timeoutToken = default;
-                base.Dispose();
+                _buffer = null;
+                _cancellationToken = default;
+                _timeoutToken = default;
+                base.Clear();
             }
         }
 
-        private sealed class GetFieldValueAsyncCallContext<T> : AAsyncCallContext<T>
+        private sealed class GetFieldValueAsyncCallContext<T> : SqlDataReaderAsyncCallContext<T>
         {
             private static readonly Func<Task, object, Task<T>> s_execute = SqlDataReader.GetFieldValueAsyncExecute<T>;
 
-            internal readonly int _columnIndex;
+            internal int _columnIndex;
+
+            internal GetFieldValueAsyncCallContext() { }
 
-            internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource<T> source, IDisposable disposable, int columnIndex)
+            internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource<T> source, IDisposable disposable)
                 : base(reader, source, disposable)
             {
-                _columnIndex = columnIndex;
             }
 
-            internal override Func<Task, object, Task<T>> Execute => s_execute;
-        }
-
-        private static Task<T> ExecuteAsyncCallCallback<T>(Task task, object state)
-        {
-            AAsyncCallContext<T> context = (AAsyncCallContext<T>)state;
-            return context._reader.ExecuteAsyncCall(task, context);
-        }
+            protected override void Clear()
+            {
+                _columnIndex = -1;
+                base.Clear();
+            }
 
-        private static void CompleteAsyncCallCallback<T>(Task<T> task, object state)
-        {
-            AAsyncCallContext<T> context = (AAsyncCallContext<T>)state;
-            context._reader.CompleteAsyncCall<T>(task, context);
+            internal override Func<Task, object, Task<T>> Execute => s_execute;
         }
 
-        private Task<T> InvokeAsyncCall<T>(AAsyncCallContext<T> context)
+        /// <summary>
+        /// Starts the process of executing an async call using an SqlDataReaderAsyncCallContext derived context object.
+        /// After this call the context lifetime is handled by BeginAsyncCall ContinueAsyncCall and CompleteAsyncCall AsyncCall methods
+        ///
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="context"></param>
+        /// <returns></returns>
+        private Task<T> InvokeAsyncCall<T>(SqlDataReaderAsyncCallContext<T> context)
         {
-            TaskCompletionSource<T> source = context._source;
+            TaskCompletionSource<T> source = context.Source;
             try
             {
                 Task<T> task;
@@ -5269,7 +5248,7 @@ private Task<T> InvokeAsyncCall<T>(AAsyncCallContext<T> context)
                 else
                 {
                     task.ContinueWith(
-                        continuationAction: AAsyncCallContext<T>.s_completeCallback,
+                        continuationAction: SqlDataReaderAsyncCallContext<T>.s_completeCallback,
                         state: context,
                         TaskScheduler.Default
                     );
@@ -5288,7 +5267,13 @@ private Task<T> InvokeAsyncCall<T>(AAsyncCallContext<T> context)
             return source.Task;
         }
 
-        private Task<T> ExecuteAsyncCall<T>(AAsyncCallContext<T> context)
+        /// <summary>
+        /// Begins an async call checking for cancellation and then setting up the callback for when data is available
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="context"></param>
+        /// <returns></returns>
+        private Task<T> ExecuteAsyncCall<T>(SqlDataReaderAsyncCallContext<T> context)
         {
             // _networkPacketTaskSource could be null if the connection was closed
             // while an async invocation was outstanding.
@@ -5301,14 +5286,23 @@ private Task<T> ExecuteAsyncCall<T>(AAsyncCallContext<T> context)
             else
             {
                 return completionSource.Task.ContinueWith(
-                    continuationFunction: AAsyncCallContext<T>.s_executeCallback,
+                    continuationFunction: SqlDataReaderAsyncCallContext<T>.s_executeCallback,
                     state: context,
                     TaskScheduler.Default
                 ).Unwrap();
             }
         }
 
-        private Task<T> ExecuteAsyncCall<T>(Task task, AAsyncCallContext<T> context)
+        /// <summary>
+        /// When data has become available for an async call it is woken and this method is called.
+        /// It will call the async execution func and if a Task is returned indicating more data
+        /// is needed it will wait until it is called again when more is available
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="task"></param>
+        /// <param name="context"></param>
+        /// <returns></returns>
+        private Task<T> ContinueAsyncCall<T>(Task task, SqlDataReaderAsyncCallContext<T> context)
         {
             // this function must be an instance function called from the static callback because otherwise a compiler error
             // is caused by accessing the _cancelAsyncOnCloseToken field of a MarshalByRefObject derived class
@@ -5361,9 +5355,16 @@ private Task<T> ExecuteAsyncCall<T>(Task task, AAsyncCallContext<T> context)
             return Task.FromException<T>(ADP.ExceptionWithStackTrace(ADP.ClosedConnectionError()));
         }
 
-        private void CompleteAsyncCall<T>(Task<T> task, AAsyncCallContext<T> context)
+        /// <summary>
+        /// When data has been successfully processed for an async call the async func will call this
+        /// function to set the result into the task and cleanup the async state ready for another call
+        /// </summary>
+        /// <typeparam name="T"></typeparam>
+        /// <param name="task"></param>
+        /// <param name="context"></param>
+        private void CompleteAsyncCall<T>(Task<T> task, SqlDataReaderAsyncCallContext<T> context)
         {
-            TaskCompletionSource<T> source = context._source;
+            TaskCompletionSource<T> source = context.Source;
             context.Dispose();
 
             // If something has forced us to switch to SyncOverAsync mode while in an async task then we need to guarantee that we do the cleanup
@@ -5390,6 +5391,28 @@ private void CompleteAsyncCall<T>(Task<T> task, AAsyncCallContext<T> context)
             }
         }
 
+
+        internal class Snapshot
+        {
+            public bool _dataReady;
+            public bool _haltRead;
+            public bool _metaDataConsumed;
+            public bool _browseModeInfoConsumed;
+            public bool _hasRows;
+            public ALTROWSTATUS _altRowStatus;
+            public int _nextColumnDataToRead;
+            public int _nextColumnHeaderToRead;
+            public long _columnDataBytesRead;
+            public long _columnDataBytesRemaining;
+
+            public _SqlMetaDataSet _metadata;
+            public _SqlMetaDataSetCollection _altMetaDataSetCollection;
+            public MultiPartTableName[] _tableNames;
+
+            public SqlSequentialStream _currentStream;
+            public SqlSequentialTextReader _currentTextReader;
+        }
+
         private void PrepareAsyncInvocation(bool useSnapshot)
         {
             // if there is already a snapshot, then the previous async command
@@ -5600,5 +5623,5 @@ private ReadOnlyCollection<DbColumn> BuildColumnSchema()
 
             return new ReadOnlyCollection<DbColumn>(columnSchema);
         }
-    }// SqlDataReader
-}// namespace
+    }
+}