diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index fcfbfa8e6df786..c49cf1d195d438 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -566,14 +566,9 @@ private async ValueTask DrainContentLength0Frames(CancellationToken cancellation switch (frameType) { case Http3FrameType.Headers: - // Pick up any trailing headers. - _trailingHeaders = new List<(HeaderDescriptor name, string value)>(); - await ReadHeadersAsync(payloadLength, cancellationToken).ConfigureAwait(false); - - // Stop looping after a trailing header. - // There may be extra frames after this one, but they would all be unknown extension - // frames that can be safely ignored. Just stop reading here. - // Note: this does leave us open to a bad server sending us an out of order DATA frame. + // Pick up any trailing headers and stop processing. + await ProcessTrailersAsync(payloadLength, cancellationToken).ConfigureAwait(false); + goto case null; case null: // Done receiving: copy over trailing headers. @@ -601,6 +596,25 @@ private async ValueTask DrainContentLength0Frames(CancellationToken cancellation } } + private async ValueTask ProcessTrailersAsync(long payloadLength, CancellationToken cancellationToken) + { + _trailingHeaders = new List<(HeaderDescriptor name, string value)>(); + await ReadHeadersAsync(payloadLength, cancellationToken).ConfigureAwait(false); + + // In typical cases, there should be no more frames. Make sure to read the EOS. + _recvBuffer.EnsureAvailableSpace(1); + int bytesRead = await _stream.ReadAsync(_recvBuffer.AvailableMemory, cancellationToken).ConfigureAwait(false); + if (bytesRead > 0) + { + // The server may send us frames of unknown types after the trailer. Ideally we should drain the response by eating and ignoring them + // but this is a rare case so we just stop reading and let Dispose() send an ABORT_RECEIVE. + // Note: if a server sends additional HEADERS or DATA frames at this point, it + // would be a connection error -- not draining the stream also means we won't catch this. + _recvBuffer.Commit(bytesRead); + _recvBuffer.Discard(bytesRead); + } + } + private void CopyTrailersToResponseMessage(HttpResponseMessage responseMessage) { if (_trailingHeaders?.Count > 0) @@ -1367,15 +1381,9 @@ private async ValueTask ReadNextDataFrameAsync(HttpResponseMessage respons _responseDataPayloadRemaining = payloadLength; return true; case Http3FrameType.Headers: - // Read any trailing headers. - _trailingHeaders = new List<(HeaderDescriptor name, string value)>(); - await ReadHeadersAsync(payloadLength, cancellationToken).ConfigureAwait(false); - - // There may be more frames after this one, but they would all be unknown extension - // frames that we are allowed to skip. Just close the stream early. + // Pick up any trailing headers and stop processing. + await ProcessTrailersAsync(payloadLength, cancellationToken).ConfigureAwait(false); - // Note: if a server sends additional HEADERS or DATA frames at this point, it - // would be a connection error -- not draining the stream means we won't catch this. goto case null; case null: // End of stream. diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 478c9ce0c88838..2577094cb52b4e 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -7,11 +7,9 @@ using System.IO.Pipes; using System.Linq; using System.Net.Http.Headers; -using System.Net.Quic; using System.Net.Security; using System.Net.Sockets; using System.Net.Test.Common; -using System.Numerics; using System.Reflection; using System.Runtime.CompilerServices; using System.Security.Authentication; @@ -1285,6 +1283,69 @@ protected override async Task AcceptConnectionAndSendResponseAsync( await stream.SendResponseHeadersAsync(statusCode: null, headers: trailers); stream.Stream.CompleteWrites(); } + + [Theory] + [InlineData(false, HttpCompletionOption.ResponseContentRead)] + [InlineData(false, HttpCompletionOption.ResponseHeadersRead)] + [InlineData(true, HttpCompletionOption.ResponseContentRead)] + [InlineData(true, HttpCompletionOption.ResponseHeadersRead)] + public async Task GetAsync_TrailersWithoutServerStreamClosure_Success(bool emptyResponse, HttpCompletionOption httpCompletionOption) + { + SemaphoreSlim serverCompleted = new SemaphoreSlim(0); + + await LoopbackServerFactory.CreateClientAndServerAsync(async uri => + { + HttpClientHandler handler = CreateHttpClientHandler(); + + // Avoid drain timeout if CI is slow. + GetUnderlyingSocketsHttpHandler(handler).ResponseDrainTimeout = TimeSpan.FromSeconds(10); + using HttpClient client = CreateHttpClient(handler); + + using (HttpResponseMessage response = await client.GetAsync(uri, httpCompletionOption)) + { + if (httpCompletionOption == HttpCompletionOption.ResponseHeadersRead) + { + using Stream stream = await response.Content.ReadAsStreamAsync(); + byte[] buffer = new byte[512]; + // Consume the stream + while ((await stream.ReadAsync(buffer)) > 0) ; + } + + Assert.Equal(TrailingHeaders.Count, response.TrailingHeaders.Count()); + } + + await serverCompleted.WaitAsync(); + }, + async server => + { + try + { + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync(); + _ = await stream.ReadRequestDataAsync(); + + HttpHeaderData[] headers = emptyResponse ? [new HttpHeaderData("Content-Length", "0")] : null; + + await stream.SendResponseHeadersAsync(statusCode: HttpStatusCode.OK, headers); + if (!emptyResponse) + { + await stream.SendResponseBodyAsync(new byte[16384], isFinal: false); + } + + await stream.SendResponseHeadersAsync(statusCode: null, headers: TrailingHeaders); + + // Small delay to make sure we do test if the client is waiting for EOS. + await Task.Delay(15); + + await stream.DisposeAsync(); + await stream.Stream.WritesClosed; + } + finally + { + serverCompleted.Release(); + } + }).WaitAsync(TimeSpan.FromSeconds(30)); + } } public sealed class SocketsHttpHandler_HttpClientHandlerTest : HttpClientHandlerTest