diff --git a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs index 27213cc2a0ed39..99bcc38f0a7159 100644 --- a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs +++ b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs @@ -168,6 +168,7 @@ internal static partial class WinHttp public const uint WINHTTP_OPTION_TCP_KEEPALIVE = 152; public const uint WINHTTP_OPTION_STREAM_ERROR_CODE = 159; + public const uint WINHTTP_OPTION_REQUIRE_STREAM_END = 160; public enum WINHTTP_WEB_SOCKET_BUFFER_TYPE { diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs index 3b4c247b8b7032..07133b2667c229 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs @@ -1120,6 +1120,7 @@ private void SetSessionHandleOptions(SafeWinHttpHandle sessionHandle) SetSessionHandleTimeoutOptions(sessionHandle); SetDisableHttp2StreamQueue(sessionHandle); SetTcpKeepalive(sessionHandle); + SetRequireStreamEnd(sessionHandle); } private unsafe void SetTcpKeepalive(SafeWinHttpHandle sessionHandle) @@ -1145,6 +1146,27 @@ private unsafe void SetTcpKeepalive(SafeWinHttpHandle sessionHandle) } } + private void SetRequireStreamEnd(SafeWinHttpHandle sessionHandle) + { + if (WinHttpTrailersHelper.OsSupportsTrailers) + { + // Setting WINHTTP_OPTION_REQUIRE_STREAM_END to TRUE is needed for WinHttp to read trailing headers + // in case the response has Content-Lenght defined. + // According to the WinHttp team, the feature-detection logic in WinHttpTrailersHelper.OsSupportsTrailers + // should also indicate the support of WINHTTP_OPTION_REQUIRE_STREAM_END. + // WINHTTP_OPTION_REQUIRE_STREAM_END doesn't have effect on HTTP 1.1 requests, therefore it's safe to set it on + // the session handle so it is inhereted by all request handles. + uint optionData = 1; + if (!Interop.WinHttp.WinHttpSetOption(sessionHandle, Interop.WinHttp.WINHTTP_OPTION_REQUIRE_STREAM_END, ref optionData)) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, "Failed to enable WINHTTP_OPTION_REQUIRE_STREAM_END error code: " + Marshal.GetLastWin32Error()); + } + } + } + } + private void SetSessionHandleConnectionOptions(SafeWinHttpHandle sessionHandle) { uint optionData = (uint)_maxConnectionsPerServer; diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs index cef52465524dcf..5ac753c266ee0a 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/TrailingHeadersTest.cs @@ -67,8 +67,10 @@ public async Task Http2GetAsync_NoTrailingHeaders_EmptyCollection() } } - [ConditionalFact(nameof(TestsEnabled))] - public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted() + [InlineData(false)] + [InlineData(true)] + [ConditionalTheory(nameof(TestsEnabled))] + public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted(bool responseHasContentLength) { using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) using (HttpClient client = CreateHttpClient()) @@ -80,7 +82,14 @@ public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted() int streamId = await connection.ReadRequestHeaderAsync(); // Response header. - await connection.SendDefaultResponseHeadersAsync(streamId); + if (responseHasContentLength) + { + await connection.SendResponseHeadersAsync(streamId, endStream: false, headers: new[] { new HttpHeaderData("Content-Length", DataBytes.Length.ToString()) }); + } + else + { + await connection.SendDefaultResponseHeadersAsync(streamId); + } // Response data, missing Trailers. await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes)); @@ -98,8 +107,10 @@ public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted() } } - [ConditionalFact(nameof(TestsEnabled))] - public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Available() + [InlineData(false)] + [InlineData(true)] + [ConditionalTheory(nameof(TestsEnabled))] + public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Available(bool responseHasContentLength) { using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) using (HttpClient client = CreateHttpClient()) @@ -111,7 +122,14 @@ public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Availab int streamId = await connection.ReadRequestHeaderAsync(); // Response header. - await connection.SendDefaultResponseHeadersAsync(streamId); + if (responseHasContentLength) + { + await connection.SendResponseHeadersAsync(streamId, endStream: false, headers: new[] { new HttpHeaderData("Content-Length", DataBytes.Length.ToString()) }); + } + else + { + await connection.SendDefaultResponseHeadersAsync(streamId); + } // Response data, missing Trailers. await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes)); diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index 14f444b1935a60..bfd1e83bfad502 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -837,8 +837,10 @@ public async Task Http2GetAsync_NoTrailingHeaders_EmptyCollection() } } - [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] - public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted() + [InlineData(false)] + [InlineData(true)] + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted(bool responseHasContentLength) { using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) using (HttpClient client = CreateHttpClient()) @@ -850,7 +852,14 @@ public async Task Http2GetAsync_MissingTrailer_TrailingHeadersAccepted() int streamId = await connection.ReadRequestHeaderAsync(); // Response header. - await connection.SendDefaultResponseHeadersAsync(streamId); + if (responseHasContentLength) + { + await connection.SendResponseHeadersAsync(streamId, endStream: false, headers: new[] { new HttpHeaderData("Content-Length", DataBytes.Length.ToString()) }); + } + else + { + await connection.SendDefaultResponseHeadersAsync(streamId); + } // Response data, missing Trailers. await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes)); @@ -888,8 +897,10 @@ public async Task Http2GetAsync_TrailerHeaders_TrailingPseudoHeadersThrow() } } - [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] - public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Available() + [InlineData(false)] + [InlineData(true)] + [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.SupportsAlpn))] + public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Available(bool responseHasContentLength) { using (Http2LoopbackServer server = Http2LoopbackServer.CreateServer()) using (HttpClient client = CreateHttpClient()) @@ -901,7 +912,14 @@ public async Task Http2GetAsyncResponseHeadersReadOption_TrailingHeaders_Availab int streamId = await connection.ReadRequestHeaderAsync(); // Response header. - await connection.SendDefaultResponseHeadersAsync(streamId); + if (responseHasContentLength) + { + await connection.SendResponseHeadersAsync(streamId, endStream: false, headers: new[] { new HttpHeaderData("Content-Length", DataBytes.Length.ToString()) }); + } + else + { + await connection.SendDefaultResponseHeadersAsync(streamId); + } // Response data, missing Trailers. await connection.WriteFrameAsync(MakeDataFrame(streamId, DataBytes));