diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs index eb5294b2a134..af1860f28f91 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -121,6 +121,7 @@ protected override void OnReset() { _keepAlive = true; _connectionAborted = false; + _userTrailers = null; // Reset Http2 Features _currentIHttpMinRequestBodyDataRateFeature = this; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index ea69197bc27e..31d2bb8811f4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -726,6 +726,7 @@ protected override void OnReset() { _keepAlive = true; _connectionAborted = false; + _userTrailers = null; // Reset Http3 Features _currentIHttpMinRequestBodyDataRateFeature = this; diff --git a/src/Servers/Kestrel/shared/test/Http3/Http3InMemory.cs b/src/Servers/Kestrel/shared/test/Http3/Http3InMemory.cs index 2cbc890d3c4e..1e443ba2709b 100644 --- a/src/Servers/Kestrel/shared/test/Http3/Http3InMemory.cs +++ b/src/Servers/Kestrel/shared/test/Http3/Http3InMemory.cs @@ -687,6 +687,17 @@ internal async ValueTask> ExpectDataAsync() return http3WithPayload.Payload; } + internal async ValueTask> ExpectTrailersAsync() + { + var http3WithPayload = await ReceiveFrameAsync(false, true); + Http3InMemory.AssertFrameType(http3WithPayload.Type, Http3FrameType.Headers); + + _headerHandler.DecodedHeaders.Clear(); + _headerHandler.QpackDecoder.Decode(http3WithPayload.PayloadSequence, this); + _headerHandler.QpackDecoder.Reset(); + return _headerHandler.DecodedHeaders.ToDictionary(kvp => kvp.Key, kvp => kvp.Value, _headerHandler.DecodedHeaders.Comparer); + } + internal async Task ExpectReceiveEndOfStream() { var result = await ReadApplicationInputAsync(); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs index 3016d57237f5..51d58036f5e4 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs @@ -252,23 +252,23 @@ public async Task ResponseTrailers_MultipleStreams_Reset() }; var requestCount = 0; + IHeaderDictionary trailersFirst = null; + IHeaderDictionary trailersLast = null; await InitializeConnectionAsync(context => { requestCount++; var trailersFeature = context.Features.Get(); - - IHeaderDictionary trailers; if (requestCount == 1) { - trailers = new ResponseTrailersWrapper(trailersFeature.Trailers); - trailersFeature.Trailers = trailers; + trailersFirst = new ResponseTrailersWrapper(trailersFeature.Trailers); + trailersFeature.Trailers = trailersFirst; } else { - trailers = trailersFeature.Trailers; + trailersLast = trailersFeature.Trailers; } - trailers["trailer-" + requestCount] = "true"; + trailersFeature.Trailers["trailer-" + requestCount] = "true"; return Task.CompletedTask; }); @@ -291,41 +291,55 @@ await ExpectAsync(Http2FrameType.HEADERS, _decodedHeaders.Clear(); - // Ping will trigger the stream to be returned to the pool so we can assert it - await SendPingAsync(Http2PingFrameFlags.NONE); - await ExpectAsync(Http2FrameType.PING, - withLength: 8, - withFlags: (byte)Http2PingFrameFlags.ACK, - withStreamId: 0); - await SendPingAsync(Http2PingFrameFlags.NONE); - await ExpectAsync(Http2FrameType.PING, - withLength: 8, - withFlags: (byte)Http2PingFrameFlags.ACK, - withStreamId: 0); + for (int i = 1; i < 3; i++) + { + int streamId = i * 2 + 1; + // Ping will trigger the stream to be returned to the pool so we can assert it + await PingAsync(); - // Stream has been returned to the pool - Assert.Equal(1, _connection.StreamPool.Count); + // Stream has been returned to the pool + Assert.Equal(1, _connection.StreamPool.Count); - await StartStreamAsync(3, requestHeaders, endStream: true); + await StartStreamAsync(streamId, requestHeaders, endStream: true); - await ExpectAsync(Http2FrameType.HEADERS, - withLength: 6, - withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), - withStreamId: 3); + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 6, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: streamId); - trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, - withLength: 16, - withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), - withStreamId: 3); + trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: streamId); - _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); - Assert.Single(_decodedHeaders); - Assert.Equal("true", _decodedHeaders["trailer-2"]); + Assert.Single(_decodedHeaders); + Assert.Equal("true", _decodedHeaders[$"trailer-{i + 1}"]); - _decodedHeaders.Clear(); + _decodedHeaders.Clear(); - await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + } + + Assert.NotNull(trailersFirst); + Assert.NotNull(trailersLast); + Assert.NotSame(trailersFirst, trailersLast); + + await StopConnectionAsync(expectedLastStreamId: 5, ignoreNonGoAwayFrames: false); + + async Task PingAsync() + { + await SendPingAsync(Http2PingFrameFlags.NONE); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + await SendPingAsync(Http2PingFrameFlags.NONE); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + } } [Fact] diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs index d75f0d65f343..27ee46bbaf02 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Collections; using System.Collections.Generic; using System.Globalization; using System.Net.Http; @@ -10,10 +11,12 @@ using System.Text.RegularExpressions; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; using Xunit; using Http3SettingType = Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3.Http3SettingType; @@ -393,6 +396,57 @@ public async Task VariableMultipleStreamsInSequence_Success(int count, bool send } } + [Fact] + public async Task ResponseTrailers_MultipleStreams_Reset() + { + var requestHeaders = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/hello"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.Authority, "localhost:80"), + new KeyValuePair(HeaderNames.ContentType, "application/json") + }; + + var requestCount = 0; + IHeaderDictionary trailersFirst = null; + IHeaderDictionary trailersLast = null; + await Http3Api.InitializeConnectionAsync(context => + { + var trailersFeature = context.Features.Get(); + if (requestCount == 0) + { + trailersFirst = new ResponseTrailersWrapper(trailersFeature.Trailers); + trailersFeature.Trailers = trailersFirst; + } + else + { + trailersLast = trailersFeature.Trailers; + } + trailersFeature.Trailers[$"trailer-{requestCount++}"] = "true"; + return Task.CompletedTask; + }); + + + for (int i = 0; i < 3; i++) + { + var requestStream = await Http3Api.CreateRequestStream(); + await requestStream.SendHeadersAsync(requestHeaders, endStream: true); + var responseHeaders = await requestStream.ExpectHeadersAsync(); + + var data = await requestStream.ExpectTrailersAsync(); + Assert.Single(data); + Assert.True(data.TryGetValue($"trailer-{i}", out var trailerValue) && bool.Parse(trailerValue)); + + await requestStream.ExpectReceiveEndOfStream(); + await requestStream.OnDisposedTask.DefaultTimeout(); + } + + Assert.NotNull(trailersFirst); + Assert.NotNull(trailersLast); + Assert.NotSame(trailersFirst, trailersLast); + } + private async Task MakeRequestAsync(int index, KeyValuePair[] headers, bool sendData, bool waitForServerDispose) { var requestStream = await Http3Api.CreateRequestStream(); @@ -431,5 +485,33 @@ private async Task MakeRequestAsync(int index, KeyValuePair _innerHeaders[key]; set => _innerHeaders[key] = value; } + public long? ContentLength { get => _innerHeaders.ContentLength; set => _innerHeaders.ContentLength = value; } + public ICollection Keys => _innerHeaders.Keys; + public ICollection Values => _innerHeaders.Values; + public int Count => _innerHeaders.Count; + public bool IsReadOnly => _innerHeaders.IsReadOnly; + public void Add(string key, StringValues value) => _innerHeaders.Add(key, value); + public void Add(KeyValuePair item) => _innerHeaders.Add(item); + public void Clear() => _innerHeaders.Clear(); + public bool Contains(KeyValuePair item) => _innerHeaders.Contains(item); + public bool ContainsKey(string key) => _innerHeaders.ContainsKey(key); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => _innerHeaders.CopyTo(array, arrayIndex); + public IEnumerator> GetEnumerator() => _innerHeaders.GetEnumerator(); + public bool Remove(string key) => _innerHeaders.Remove(key); + public bool Remove(KeyValuePair item) => _innerHeaders.Remove(item); + public bool TryGetValue(string key, out StringValues value) => _innerHeaders.TryGetValue(key, out value); + IEnumerator IEnumerable.GetEnumerator() => _innerHeaders.GetEnumerator(); + } } }