Skip to content

Commit bae2f22

Browse files
authored
Make StartAsync not throw if we haven't started the response (#8199)
1 parent c0c2bb3 commit bae2f22

File tree

8 files changed

+905
-79
lines changed

8 files changed

+905
-79
lines changed

src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp3.0.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ public void Dispose() { }
484484
public System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> FlushAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
485485
public System.Memory<byte> GetMemory(int sizeHint = 0) { throw null; }
486486
public System.Span<byte> GetSpan(int sizeHint = 0) { throw null; }
487+
public void Reset() { }
487488
public System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> Write100ContinueAsync() { throw null; }
488489
public System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> WriteChunkAsync(System.ReadOnlySpan<byte> buffer, System.Threading.CancellationToken cancellationToken) { throw null; }
489490
public System.Threading.Tasks.Task WriteDataAsync(System.ReadOnlySpan<byte> buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
@@ -954,6 +955,7 @@ public partial interface IHttpOutputProducer
954955
System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> FlushAsync(System.Threading.CancellationToken cancellationToken);
955956
System.Memory<byte> GetMemory(int sizeHint = 0);
956957
System.Span<byte> GetSpan(int sizeHint = 0);
958+
void Reset();
957959
System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> Write100ContinueAsync();
958960
System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> WriteChunkAsync(System.ReadOnlySpan<byte> data, System.Threading.CancellationToken cancellationToken);
959961
System.Threading.Tasks.Task WriteDataAsync(System.ReadOnlySpan<byte> data, System.Threading.CancellationToken cancellationToken);
@@ -1297,6 +1299,7 @@ public void Dispose() { }
12971299
public System.Span<byte> GetSpan(int sizeHint = 0) { throw null; }
12981300
void Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.IHttpOutputAborter.Abort(Microsoft.AspNetCore.Connections.ConnectionAbortedException abortReason) { }
12991301
System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.IHttpOutputProducer.WriteChunkAsync(System.ReadOnlySpan<byte> data, System.Threading.CancellationToken cancellationToken) { throw null; }
1302+
public void Reset() { }
13001303
public System.Threading.Tasks.ValueTask<System.IO.Pipelines.FlushResult> Write100ContinueAsync() { throw null; }
13011304
public System.Threading.Tasks.Task WriteChunkAsync(System.ReadOnlySpan<byte> span, System.Threading.CancellationToken cancellationToken) { throw null; }
13021305
public System.Threading.Tasks.Task WriteDataAsync(System.ReadOnlySpan<byte> data, System.Threading.CancellationToken cancellationToken) { throw null; }

src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs

Lines changed: 199 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Buffers;
6+
using System.Collections.Generic;
67
using System.Diagnostics;
78
using System.IO.Pipelines;
89
using System.Threading;
@@ -26,6 +27,9 @@ public class Http1OutputProducer : IHttpOutputProducer, IHttpOutputAborter, IDis
2627
// "0\r\n\r\n"
2728
private static ReadOnlySpan<byte> EndChunkedResponseBytes => new byte[] { (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
2829

30+
private const int BeginChunkLengthMax = 5;
31+
private const int EndChunkLength = 2;
32+
2933
private readonly string _connectionId;
3034
private readonly ConnectionContext _connectionContext;
3135
private readonly IKestrelTrace _log;
@@ -40,21 +44,28 @@ public class Http1OutputProducer : IHttpOutputProducer, IHttpOutputAborter, IDis
4044
private bool _completed;
4145
private bool _aborted;
4246
private long _unflushedBytes;
43-
private bool _autoChunk;
47+
4448
private readonly PipeWriter _pipeWriter;
45-
private const int MemorySizeThreshold = 1024;
46-
private const int BeginChunkLengthMax = 5;
47-
private const int EndChunkLength = 2;
49+
private IMemoryOwner<byte> _fakeMemoryOwner;
4850

4951
// Chunked responses need to be treated uniquely when using GetMemory + Advance.
5052
// We need to know the size of the data written to the chunk before calling Advance on the
5153
// PipeWriter, meaning we internally track how far we have advanced through a current chunk (_advancedBytesForChunk).
5254
// Once write or flush is called, we modify the _currentChunkMemory to prepend the size of data written
5355
// and append the end terminator.
56+
57+
private bool _autoChunk;
5458
private int _advancedBytesForChunk;
5559
private Memory<byte> _currentChunkMemory;
5660
private bool _currentChunkMemoryUpdated;
57-
private IMemoryOwner<byte> _fakeMemoryOwner;
61+
62+
// Fields needed to store writes before calling either startAsync or Write/FlushAsync
63+
// These should be cleared by the end of the request
64+
private List<CompletedBuffer> _completedSegments;
65+
private Memory<byte> _currentSegment;
66+
private IMemoryOwner<byte> _currentSegmentOwner;
67+
private int _position;
68+
private bool _startCalled;
5869

5970
public Http1OutputProducer(
6071
PipeWriter pipeWriter,
@@ -158,6 +169,10 @@ public Memory<byte> GetMemory(int sizeHint = 0)
158169
{
159170
return GetFakeMemory(sizeHint);
160171
}
172+
else if (!_startCalled)
173+
{
174+
return LeasedMemory(sizeHint);
175+
}
161176
else if (_autoChunk)
162177
{
163178
return GetChunkedMemory(sizeHint);
@@ -177,6 +192,10 @@ public Span<byte> GetSpan(int sizeHint = 0)
177192
{
178193
return GetFakeMemory(sizeHint).Span;
179194
}
195+
else if (!_startCalled)
196+
{
197+
return LeasedMemory(sizeHint).Span;
198+
}
180199
else if (_autoChunk)
181200
{
182201
return GetChunkedMemory(sizeHint).Span;
@@ -197,16 +216,23 @@ public void Advance(int bytes)
197216
return;
198217
}
199218

200-
if (_autoChunk)
219+
if (!_startCalled)
201220
{
202-
if (bytes < 0)
221+
if (bytes >= 0)
203222
{
204-
throw new ArgumentOutOfRangeException(nameof(bytes));
205-
}
223+
if (_currentSegment.Length - bytes < _position)
224+
{
225+
throw new ArgumentOutOfRangeException("Can't advance past buffer size.");
226+
}
206227

207-
if (bytes + _advancedBytesForChunk > _currentChunkMemory.Length - BeginChunkLengthMax - EndChunkLength)
228+
_position += bytes;
229+
}
230+
}
231+
else if (_autoChunk)
232+
{
233+
if (_advancedBytesForChunk > _currentChunkMemory.Length - BeginChunkLengthMax - EndChunkLength - bytes)
208234
{
209-
throw new InvalidOperationException("Can't advance past buffer size.");
235+
throw new ArgumentOutOfRangeException("Can't advance past buffer size.");
210236
}
211237
_advancedBytesForChunk += bytes;
212238
}
@@ -238,6 +264,7 @@ public ValueTask<FlushResult> WriteChunkAsync(ReadOnlySpan<byte> buffer, Cancell
238264
{
239265
var writer = new BufferWriter<PipeWriter>(_pipeWriter);
240266
CommitChunkInternal(ref writer, buffer);
267+
_unflushedBytes += writer.BytesCommitted;
241268
}
242269
}
243270

@@ -260,7 +287,6 @@ private void CommitChunkInternal(ref BufferWriter<PipeWriter> writer, ReadOnlySp
260287
}
261288

262289
writer.Commit();
263-
_unflushedBytes += writer.BytesCommitted;
264290
}
265291

266292
public void WriteResponseHeaders(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders, bool autoChunk)
@@ -288,8 +314,52 @@ private void WriteResponseHeadersInternal(ref BufferWriter<PipeWriter> writer, i
288314

289315
writer.Commit();
290316

291-
_unflushedBytes += writer.BytesCommitted;
292317
_autoChunk = autoChunk;
318+
WriteDataWrittenBeforeHeaders(ref writer);
319+
_unflushedBytes += writer.BytesCommitted;
320+
321+
_startCalled = true;
322+
}
323+
324+
private void WriteDataWrittenBeforeHeaders(ref BufferWriter<PipeWriter> writer)
325+
{
326+
if (_completedSegments != null)
327+
{
328+
foreach (var segment in _completedSegments)
329+
{
330+
if (_autoChunk)
331+
{
332+
CommitChunkInternal(ref writer, segment.Span);
333+
}
334+
else
335+
{
336+
writer.Write(segment.Span);
337+
writer.Commit();
338+
}
339+
segment.Return();
340+
}
341+
342+
_completedSegments.Clear();
343+
}
344+
345+
if (!_currentSegment.IsEmpty)
346+
{
347+
var segment = _currentSegment.Slice(0, _position);
348+
349+
if (_autoChunk)
350+
{
351+
CommitChunkInternal(ref writer, segment.Span);
352+
}
353+
else
354+
{
355+
writer.Write(segment.Span);
356+
writer.Commit();
357+
}
358+
359+
_position = 0;
360+
361+
DisposeCurrentSegment();
362+
}
293363
}
294364

295365
public void Dispose()
@@ -302,10 +372,28 @@ public void Dispose()
302372
_fakeMemoryOwner = null;
303373
}
304374

375+
// Call dispose on any memory that wasn't written.
376+
if (_completedSegments != null)
377+
{
378+
foreach (var segment in _completedSegments)
379+
{
380+
segment.Return();
381+
}
382+
}
383+
384+
DisposeCurrentSegment();
385+
305386
CompletePipe();
306387
}
307388
}
308389

390+
private void DisposeCurrentSegment()
391+
{
392+
_currentSegmentOwner?.Dispose();
393+
_currentSegmentOwner = null;
394+
_currentSegment = default;
395+
}
396+
309397
private void CompletePipe()
310398
{
311399
if (!_pipeWriterCompleted)
@@ -382,10 +470,21 @@ public ValueTask<FlushResult> FirstWriteChunkedAsync(int statusCode, string reas
382470

383471
CommitChunkInternal(ref writer, buffer);
384472

473+
_unflushedBytes += writer.BytesCommitted;
474+
385475
return FlushAsync(cancellationToken);
386476
}
387477
}
388478

479+
public void Reset()
480+
{
481+
Debug.Assert(_currentSegmentOwner == null);
482+
Debug.Assert(_completedSegments == null || _completedSegments.Count == 0);
483+
_autoChunk = false;
484+
_startCalled = false;
485+
_currentChunkMemoryUpdated = false;
486+
}
487+
389488
private ValueTask<FlushResult> WriteAsync(
390489
ReadOnlySpan<byte> buffer,
391490
CancellationToken cancellationToken = default)
@@ -454,7 +553,7 @@ private Memory<byte> GetChunkedMemory(int sizeHint)
454553
}
455554

456555
var memoryMaxLength = _currentChunkMemory.Length - BeginChunkLengthMax - EndChunkLength;
457-
if (_advancedBytesForChunk >= memoryMaxLength - Math.Min(MemorySizeThreshold, sizeHint))
556+
if (_advancedBytesForChunk >= memoryMaxLength - sizeHint && _advancedBytesForChunk > 0)
458557
{
459558
// Chunk is completely written, commit it to the pipe so GetMemory will return a new chunk of memory.
460559
var writer = new BufferWriter<PipeWriter>(_pipeWriter);
@@ -506,5 +605,91 @@ private Memory<byte> GetFakeMemory(int sizeHint)
506605
}
507606
return _fakeMemoryOwner.Memory;
508607
}
608+
609+
private Memory<byte> LeasedMemory(int sizeHint)
610+
{
611+
EnsureCapacity(sizeHint);
612+
return _currentSegment.Slice(_position);
613+
}
614+
615+
private void EnsureCapacity(int sizeHint)
616+
{
617+
// Only subtracts _position from the current segment length if it's non-null.
618+
// If _currentSegment is null, it returns 0.
619+
var remainingSize = _currentSegment.Length - _position;
620+
621+
// If the sizeHint is 0, any capacity will do
622+
// Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment.
623+
if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint))
624+
{
625+
// We have capacity in the current segment
626+
return;
627+
}
628+
629+
AddSegment(sizeHint);
630+
}
631+
632+
private void AddSegment(int sizeHint = 0)
633+
{
634+
if (_currentSegment.Length != 0)
635+
{
636+
// We're adding a segment to the list
637+
if (_completedSegments == null)
638+
{
639+
_completedSegments = new List<CompletedBuffer>();
640+
}
641+
642+
// Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when
643+
// GetMemory was called. In that case we'll take the current segment and call it "completed", but need to
644+
// ignore any empty space in it.
645+
_completedSegments.Add(new CompletedBuffer(_currentSegmentOwner, _currentSegment, _position));
646+
}
647+
648+
if (sizeHint <= _memoryPool.MaxBufferSize)
649+
{
650+
// Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment.
651+
// Also, the size cannot be larger than the MaxBufferSize of the MemoryPool
652+
var owner = _memoryPool.Rent(Math.Min(sizeHint, _memoryPool.MaxBufferSize));
653+
_currentSegment = owner.Memory;
654+
_currentSegmentOwner = owner;
655+
}
656+
else
657+
{
658+
_currentSegment = new byte[sizeHint];
659+
_currentSegmentOwner = null;
660+
}
661+
662+
_position = 0;
663+
}
664+
665+
666+
/// <summary>
667+
/// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it.
668+
/// </summary>
669+
private readonly struct CompletedBuffer
670+
{
671+
private readonly IMemoryOwner<byte> _memoryOwner;
672+
673+
public Memory<byte> Buffer { get; }
674+
public int Length { get; }
675+
676+
public ReadOnlySpan<byte> Span => Buffer.Span.Slice(0, Length);
677+
678+
public CompletedBuffer(IMemoryOwner<byte> owner, Memory<byte> buffer, int length)
679+
{
680+
_memoryOwner = owner;
681+
682+
Buffer = buffer;
683+
Length = length;
684+
}
685+
686+
public void Return()
687+
{
688+
if (_memoryOwner != null)
689+
{
690+
_memoryOwner.Dispose();
691+
}
692+
}
693+
}
509694
}
510695
}

0 commit comments

Comments
 (0)