diff --git a/src/Microsoft.AspNet.Server.Kestrel/Filter/SocketInputStream.cs b/src/Microsoft.AspNet.Server.Kestrel/Filter/SocketInputStream.cs index 26b513faa..d6cc5748f 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Filter/SocketInputStream.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Filter/SocketInputStream.cs @@ -74,11 +74,7 @@ public override void SetLength(long value) public override void Write(byte[] buffer, int offset, int count) { - var inputBuffer = _socketInput.IncomingStart(count); - - Buffer.BlockCopy(buffer, offset, inputBuffer.Data.Array, inputBuffer.Data.Offset, count); - - _socketInput.IncomingComplete(count, error: null); + _socketInput.IncomingData(buffer, offset, count); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token) @@ -90,7 +86,7 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati protected override void Dispose(bool disposing) { // Close _socketInput with a fake zero-length write that will result in a zero-length read. - _socketInput.IncomingComplete(0, error: null); + _socketInput.IncomingData(null, 0, 0); base.Dispose(disposing); } } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs index 89b9db55e..6f2ea40f2 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs @@ -140,11 +140,11 @@ private static Libuv.uv_buf_t AllocCallback(UvStreamHandle handle, int suggested private Libuv.uv_buf_t OnAlloc(UvStreamHandle handle, int suggestedSize) { - var result = _rawSocketInput.IncomingStart(2048); + var result = _rawSocketInput.IncomingStart(); return handle.Libuv.buf_init( - result.DataPtr, - result.Data.Count); + result.Pin() + result.End, + result.Data.Offset + result.Data.Count - result.End); } private static void ReadCallback(UvStreamHandle handle, int status, object state) diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs index 5a7a1715d..efe431f94 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs @@ -5,7 +5,6 @@ using System.IO; using System.Runtime.CompilerServices; using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNet.Server.Kestrel.Infrastructure; namespace Microsoft.AspNet.Server.Kestrel.Http @@ -25,7 +24,6 @@ public class SocketInput : ICriticalNotifyCompletion private MemoryPoolBlock2 _head; private MemoryPoolBlock2 _tail; private MemoryPoolBlock2 _pinned; - private readonly object _sync = new Object(); public SocketInput(MemoryPool2 memory, IThreadPool threadPool) { @@ -34,99 +32,98 @@ public SocketInput(MemoryPool2 memory, IThreadPool threadPool) _awaitableState = _awaitableIsNotCompleted; } - public ArraySegment Buffer { get; set; } - public bool RemoteIntakeFin { get; set; } - public bool IsCompleted + public bool IsCompleted => (_awaitableState == _awaitableIsCompleted); + + public MemoryPoolBlock2 IncomingStart() { - get + const int minimumSize = 2048; + + if (_tail != null && minimumSize <= _tail.Data.Offset + _tail.Data.Count - _tail.End) { - return Equals(_awaitableState, _awaitableIsCompleted); + _pinned = _tail; + } + else + { + _pinned = _memory.Lease(); } - } - - public void Skip(int count) - { - Buffer = new ArraySegment(Buffer.Array, Buffer.Offset + count, Buffer.Count - count); - } - public ArraySegment Take(int count) - { - var taken = new ArraySegment(Buffer.Array, Buffer.Offset, count); - Skip(count); - return taken; + return _pinned; } - public IncomingBuffer IncomingStart(int minimumSize) + public void IncomingData(byte[] buffer, int offset, int count) { - lock (_sync) + if (count > 0) { - if (_tail != null && minimumSize <= _tail.Data.Offset + _tail.Data.Count - _tail.End) + if (_tail == null) { - _pinned = _tail; - var data = new ArraySegment(_pinned.Data.Array, _pinned.End, _pinned.Data.Offset + _pinned.Data.Count - _pinned.End); - var dataPtr = _pinned.Pin() + _pinned.End; - return new IncomingBuffer - { - Data = data, - DataPtr = dataPtr, - }; + _tail = _memory.Lease(); } - } - _pinned = _memory.Lease(minimumSize); - return new IncomingBuffer + var iterator = new MemoryPoolIterator2(_tail, _tail.End); + iterator.CopyFrom(buffer, offset, count); + + if (_head == null) + { + _head = _tail; + } + + _tail = iterator.Block; + } + else { - Data = _pinned.Data, - DataPtr = _pinned.Pin() + _pinned.End - }; + RemoteIntakeFin = true; + } + + Complete(); } public void IncomingComplete(int count, Exception error) { - Action awaitableState; - - lock (_sync) + // Unpin may called without an earlier Pin + if (_pinned != null) { - // Unpin may called without an earlier Pin - if (_pinned != null) + + _pinned.End += count; + + if (_head == null) { - _pinned.Unpin(); - - _pinned.End += count; - if (_head == null) - { - _head = _tail = _pinned; - } - else if (_tail == _pinned) - { - // NO-OP: this was a read into unoccupied tail-space - } - else - { - _tail.Next = _pinned; - _tail = _pinned; - } + _head = _tail = _pinned; } - _pinned = null; - - if (count == 0) + else if (_tail == _pinned) { - RemoteIntakeFin = true; + // NO-OP: this was a read into unoccupied tail-space } - if (error != null) + else { - _awaitableError = error; + _tail.Next = _pinned; + _tail = _pinned; } - awaitableState = Interlocked.Exchange( - ref _awaitableState, - _awaitableIsCompleted); + _pinned = null; + } - _manualResetEvent.Set(); + if (count == 0) + { + RemoteIntakeFin = true; + } + if (error != null) + { + _awaitableError = error; } + Complete(); + } + + private void Complete() + { + var awaitableState = Interlocked.Exchange( + ref _awaitableState, + _awaitableIsCompleted); + + _manualResetEvent.Set(); + if (awaitableState != _awaitableIsCompleted && awaitableState != _awaitableIsNotCompleted) { @@ -136,10 +133,7 @@ public void IncomingComplete(int count, Exception error) public MemoryPoolIterator2 ConsumingStart() { - lock (_sync) - { - return new MemoryPoolIterator2(_head); - } + return new MemoryPoolIterator2(_head); } public void ConsumingComplete( @@ -148,33 +142,31 @@ public void ConsumingComplete( { MemoryPoolBlock2 returnStart = null; MemoryPoolBlock2 returnEnd = null; - lock (_sync) + if (!consumed.IsDefault) { - if (!consumed.IsDefault) - { - returnStart = _head; - returnEnd = consumed.Block; - _head = consumed.Block; - _head.Start = consumed.Index; - } - if (!examined.IsDefault && - examined.IsEnd && - RemoteIntakeFin == false && - _awaitableError == null) - { - _manualResetEvent.Reset(); + returnStart = _head; + returnEnd = consumed.Block; + _head = consumed.Block; + _head.Start = consumed.Index; + } + if (!examined.IsDefault && + examined.IsEnd && + RemoteIntakeFin == false && + _awaitableError == null) + { + _manualResetEvent.Reset(); - var awaitableState = Interlocked.CompareExchange( - ref _awaitableState, - _awaitableIsNotCompleted, - _awaitableIsCompleted); - } + var awaitableState = Interlocked.CompareExchange( + ref _awaitableState, + _awaitableIsNotCompleted, + _awaitableIsCompleted); } + while (returnStart != returnEnd) { var returnBlock = returnStart; returnStart = returnStart.Next; - returnBlock.Pool?.Return(returnBlock); + returnBlock.Pool.Return(returnBlock); } } @@ -182,17 +174,7 @@ public void AbortAwaiting() { _awaitableError = new ObjectDisposedException(nameof(SocketInput), "The request was aborted"); - var awaitableState = Interlocked.Exchange( - ref _awaitableState, - _awaitableIsCompleted); - - _manualResetEvent.Set(); - - if (awaitableState != _awaitableIsCompleted && - awaitableState != _awaitableIsNotCompleted) - { - _threadPool.Run(awaitableState); - } + Complete(); } public SocketInput GetAwaiter() @@ -247,11 +229,5 @@ public void GetResult() throw new IOException(error.Message, error); } } - - public struct IncomingBuffer - { - public ArraySegment Data; - public IntPtr DataPtr; - } } } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs index 39cb3202a..d514c3893 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs @@ -1,8 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Threading.Tasks; +using Microsoft.AspNet.Server.Kestrel.Infrastructure; namespace Microsoft.AspNet.Server.Kestrel.Http { @@ -10,15 +10,9 @@ public static class SocketInputExtensions { public static ValueTask ReadAsync(this SocketInput input, byte[] buffer, int offset, int count) { - while (true) + while (input.IsCompleted) { - if (!input.IsCompleted) - { - return input.ReadAsyncAwaited(buffer, offset, count); - } - var begin = input.ConsumingStart(); - int actual; var end = begin.CopyTo(buffer, offset, count, out actual); input.ConsumingComplete(end, end); @@ -32,6 +26,8 @@ public static ValueTask ReadAsync(this SocketInput input, byte[] buffer, in return 0; } } + + return input.ReadAsyncAwaited(buffer, offset, count); } private static async Task ReadAsyncAwaited(this SocketInput input, byte[] buffer, int offset, int count) diff --git a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/TaskUtilities.cs b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/TaskUtilities.cs index e67ded5ac..6a56e7507 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/TaskUtilities.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/TaskUtilities.cs @@ -12,5 +12,6 @@ public static class TaskUtilities #else public static Task CompletedTask = Task.FromResult(null); #endif + public static Task ZeroTask = Task.FromResult(0); } } \ No newline at end of file diff --git a/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs index cc9bc2661..793084ef5 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/FrameTests.cs @@ -55,9 +55,7 @@ public void EmptyHeaderValuesCanBeParsed(string rawHeaders, int numHeaders) var headerCollection = new FrameRequestHeaders(); var headerArray = Encoding.ASCII.GetBytes(rawHeaders); - var inputBuffer = socketInput.IncomingStart(headerArray.Length); - Buffer.BlockCopy(headerArray, 0, inputBuffer.Data.Array, inputBuffer.Data.Offset, headerArray.Length); - socketInput.IncomingComplete(headerArray.Length, null); + socketInput.IncomingData(headerArray, 0, headerArray.Length); var success = Frame.TakeMessageHeaders(socketInput, headerCollection); diff --git a/test/Microsoft.AspNet.Server.KestrelTests/TestInput.cs b/test/Microsoft.AspNet.Server.KestrelTests/TestInput.cs index 827d47ba8..14c321390 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/TestInput.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/TestInput.cs @@ -29,11 +29,8 @@ public TestInput() public void Add(string text, bool fin = false) { - var encoding = System.Text.Encoding.ASCII; - var count = encoding.GetByteCount(text); - var buffer = FrameContext.SocketInput.IncomingStart(text.Length); - count = encoding.GetBytes(text, 0, text.Length, buffer.Data.Array, buffer.Data.Offset); - FrameContext.SocketInput.IncomingComplete(count, null); + var data = System.Text.Encoding.ASCII.GetBytes(text); + FrameContext.SocketInput.IncomingData(data, 0, data.Length); if (fin) { FrameContext.SocketInput.RemoteIntakeFin = true;