diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index f85940605..8312e88be 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Globalization; using System.IO; using System.Linq; using System.Net; @@ -75,7 +76,7 @@ public abstract partial class Frame : IFrameControl protected readonly long _keepAliveMilliseconds; private readonly long _requestHeadersTimeoutMilliseconds; - private int _responseBytesWritten; + protected long _responseBytesWritten; public Frame(ConnectionContext context) { @@ -516,8 +517,8 @@ public async Task FlushAsync(CancellationToken cancellationToken) public void Write(ArraySegment data) { + VerifyAndUpdateWrite(data.Count); ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); - _responseBytesWritten += data.Count; if (_canHaveBody) { @@ -547,7 +548,7 @@ public Task WriteAsync(ArraySegment data, CancellationToken cancellationTo return WriteAsyncAwaited(data, cancellationToken); } - _responseBytesWritten += data.Count; + VerifyAndUpdateWrite(data.Count); if (_canHaveBody) { @@ -573,8 +574,9 @@ public Task WriteAsync(ArraySegment data, CancellationToken cancellationTo public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken cancellationToken) { + VerifyAndUpdateWrite(data.Count); + await ProduceStartAndFireOnStarting(); - _responseBytesWritten += data.Count; if (_canHaveBody) { @@ -598,6 +600,23 @@ public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken c } } + private void VerifyAndUpdateWrite(int count) + { + var responseHeaders = FrameResponseHeaders; + + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.HasContentLength && + _responseBytesWritten + count > responseHeaders.HeaderContentLengthValue.Value) + { + _keepAlive = false; + throw new InvalidOperationException( + $"Response Content-Length mismatch: too many bytes written ({_responseBytesWritten + count} of {responseHeaders.HeaderContentLengthValue.Value})."); + } + + _responseBytesWritten += count; + } + private void WriteChunked(ArraySegment data) { SocketOutput.Write(data, chunk: true); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs index d37c1c8cb..88fa8d6ac 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.Generated.cs @@ -3697,6 +3697,7 @@ protected override void ClearFast() { _bits = 0; _headers = default(HeaderReferences); + MaybeUnknown?.Clear(); } @@ -5670,6 +5671,7 @@ public StringValues HeaderContentLength } set { + _contentLength = ParseContentLength(value); _bits |= 2048L; _headers._ContentLength = value; _headers._rawContentLength = null; @@ -7384,6 +7386,7 @@ protected override void SetValueFast(string key, StringValues value) { if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase)) { + _contentLength = ParseContentLength(value); _bits |= 2048L; _headers._ContentLength = value; _headers._rawContentLength = null; @@ -7809,6 +7812,7 @@ protected override void AddValueFast(string key, StringValues value) { ThrowDuplicateKeyException(); } + _contentLength = ParseContentLength(value); _bits |= 2048L; _headers._ContentLength = value; _headers._rawContentLength = null; @@ -8350,6 +8354,7 @@ protected override bool RemoveFast(string key) { if (((_bits & 2048L) != 0)) { + _contentLength = null; _bits &= ~2048L; _headers._ContentLength = StringValues.Empty; _headers._rawContentLength = null; @@ -8601,6 +8606,7 @@ protected override void ClearFast() { _bits = 0; _headers = default(HeaderReferences); + _contentLength = null; MaybeUnknown?.Clear(); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs index ac283a006..c6b9d0c59 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameHeaders.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Globalization; using System.Linq; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Primitives; @@ -232,6 +233,18 @@ public static void ValidateHeaderCharacters(string headerCharacters) } } + public static long ParseContentLength(StringValues value) + { + try + { + return long.Parse(value, NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite, CultureInfo.InvariantCulture); + } + catch (FormatException ex) + { + throw new InvalidOperationException("Content-Length value must be an integral number.", ex); + } + } + private static void ThrowInvalidHeaderCharacter(char ch) { throw new InvalidOperationException(string.Format("Invalid non-ASCII or control character in header: 0x{0:X4}", (ushort)ch)); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs index eeba9695d..5fc2a56d2 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs @@ -92,6 +92,16 @@ public override async Task RequestProcessingAsync() try { await _application.ProcessRequestAsync(context).ConfigureAwait(false); + + var responseHeaders = FrameResponseHeaders; + if (!responseHeaders.HasTransferEncoding && + responseHeaders.HasContentLength && + _responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value) + { + _keepAlive = false; + ReportApplicationError(new InvalidOperationException( + $"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value}).")); + } } catch (Exception ex) { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs index 9ea8056f2..ccb5e0d46 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameResponseHeaders.cs @@ -13,6 +13,8 @@ public partial class FrameResponseHeaders : FrameHeaders private static readonly byte[] _CrLf = new[] { (byte)'\r', (byte)'\n' }; private static readonly byte[] _colonSpace = new[] { (byte)':', (byte)' ' }; + private long? _contentLength; + public bool HasConnection => HeaderConnection.Count != 0; public bool HasTransferEncoding => HeaderTransferEncoding.Count != 0; @@ -23,6 +25,8 @@ public partial class FrameResponseHeaders : FrameHeaders public bool HasDate => HeaderDate.Count != 0; + public long? HeaderContentLengthValue => _contentLength; + public Enumerator GetEnumerator() { return new Enumerator(this); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs index 900adc876..2be2acbfc 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/IKestrelTrace.cs @@ -33,7 +33,7 @@ public interface IKestrelTrace : ILogger void ConnectionDisconnectedWrite(string connectionId, int count, Exception ex); - void ConnectionHeadResponseBodyWrite(string connectionId, int count); + void ConnectionHeadResponseBodyWrite(string connectionId, long count); void ConnectionBadRequest(string connectionId, BadHttpRequestException ex); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs index 5d3ecff7b..fcdc3f920 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/KestrelTrace.cs @@ -24,7 +24,7 @@ public class KestrelTrace : IKestrelTrace private static readonly Action _applicationError; private static readonly Action _connectionError; private static readonly Action _connectionDisconnectedWrite; - private static readonly Action _connectionHeadResponseBodyWrite; + private static readonly Action _connectionHeadResponseBodyWrite; private static readonly Action _notAllConnectionsClosedGracefully; private static readonly Action _connectionBadRequest; @@ -49,7 +49,7 @@ static KestrelTrace() _connectionDisconnectedWrite = LoggerMessage.Define(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client."); _notAllConnectionsClosedGracefully = LoggerMessage.Define(LogLevel.Debug, 16, "Some connections failed to close gracefully during server shutdown."); _connectionBadRequest = LoggerMessage.Define(LogLevel.Information, 17, @"Connection id ""{ConnectionId}"" bad request data: ""{message}"""); - _connectionHeadResponseBodyWrite = LoggerMessage.Define(LogLevel.Debug, 18, @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response."); + _connectionHeadResponseBodyWrite = LoggerMessage.Define(LogLevel.Debug, 18, @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response."); } public KestrelTrace(ILogger logger) @@ -135,7 +135,7 @@ public virtual void ConnectionDisconnectedWrite(string connectionId, int count, _connectionDisconnectedWrite(_logger, connectionId, count, ex); } - public virtual void ConnectionHeadResponseBodyWrite(string connectionId, int count) + public virtual void ConnectionHeadResponseBodyWrite(string connectionId, long count) { _connectionHeadResponseBodyWrite(_logger, connectionId, count, null); } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index bcd2e7ecf..57e753d07 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.Linq; using System.Net; using System.Net.Http; @@ -14,6 +15,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; using Moq; using Xunit; @@ -85,7 +87,7 @@ public async Task IgnoreNullHeaderValues(string headerName, StringValues headerV app.Run(async context => { context.Response.Headers.Add(headerName, headerValue); - + await context.Response.WriteAsync(""); }); }); @@ -299,7 +301,7 @@ await connection.Receive( } [Fact] - public async Task ResponseBodyNotWrittenOnHeadResponse() + public async Task ResponseBodyNotWrittenOnHeadResponseAndLoggedOnlyOnce() { var mockKestrelTrace = new Mock(); @@ -324,7 +326,285 @@ await connection.Receive( } mockKestrelTrace.Verify(kestrelTrace => - kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), "hello, world".Length)); + kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), "hello, world".Length), Times.Once); + } + + [Fact] + public async Task WhenAppWritesMoreThanContentLengthWriteThrowsAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 11; + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes("hello,"), 0, 6); + httpContext.Response.Body.Write(Encoding.ASCII.GetBytes(" world"), 0, 6); + return TaskCache.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "hello,"); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 11).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesMoreThanContentLengthWriteAsyncThrowsAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 11; + await httpContext.Response.WriteAsync("hello,"); + await httpContext.Response.WriteAsync(" world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 11", + "", + "hello,"); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 11).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesMoreThanContentLengthAndResponseNotStarted500ResponseSentAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 5; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 5).", + logMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppWritesLessThanContentLengthErrorLogged() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 13; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 13", + "", + "hello, world"); + } + } + + var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too few bytes written (12 of 13).", + errorMessage.Exception.Message); + } + + [Fact] + public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionCloses() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(httpContext => + { + httpContext.Response.ContentLength = 5; + return TaskCache.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too few bytes written (0 of 5).", + errorMessage.Exception.Message); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task WhenAppSetsContentLengthToZeroAndDoesNotWriteNoErrorIsThrown(bool flushResponse) + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.ContentLength = 0; + + if (flushResponse) + { + await httpContext.Response.Body.FlushAsync(); + } + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.Equal(0, testLogger.ApplicationErrorsLogged); + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. + [Fact] + public async Task WhenAppSetsTransferEncodingAndContentLengthWritingLessIsNotAnError() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = "chunked"; + httpContext.Response.ContentLength = 13; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "Content-Length: 13", + "", + "hello, world"); + } + } + + Assert.Equal(0, testLogger.ApplicationErrorsLogged); + } + + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + // If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. + [Fact] + public async Task WhenAppSetsTransferEncodingAndContentLengthWritingMoreIsNotAnError() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(async httpContext => + { + httpContext.Response.Headers["Transfer-Encoding"] = "chunked"; + httpContext.Response.ContentLength = 11; + await httpContext.Response.WriteAsync("hello, world"); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + $"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "Content-Length: 11", + "", + "hello, world"); + } + } + + Assert.Equal(0, testLogger.ApplicationErrorsLogged); } public static TheoryData NullHeaderData diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index a23a063d1..f9a8b8677 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -620,35 +620,6 @@ await connection.ReceiveEnd( } } - [Theory] - [MemberData(nameof(ConnectionFilterData))] - public async Task WriteOnHeadResponseLoggedOnlyOnce(TestServiceContext testContext) - { - using (var server = new TestServer(async httpContext => - { - await httpContext.Response.WriteAsync("hello, "); - await httpContext.Response.WriteAsync("world"); - await httpContext.Response.WriteAsync("!"); - }, testContext)) - { - using (var connection = server.CreateConnection()) - { - await connection.SendEnd( - "HEAD / HTTP/1.1", - "", - ""); - await connection.ReceiveEnd( - "HTTP/1.1 200 OK", - $"Date: {testContext.DateHeaderValue}", - "", - ""); - } - - Assert.Equal(1, ((TestKestrelTrace)testContext.Log).HeadResponseWrites); - Assert.Equal(13, ((TestKestrelTrace)testContext.Log).HeadResponseWriteByteCount); - } - } - [Theory] [MemberData(nameof(ConnectionFilterData))] public async Task ThrowingResultsIn500Response(TestServiceContext testContext) @@ -697,11 +668,11 @@ await connection.ReceiveEnd( "Content-Length: 0", "", ""); - - Assert.False(onStartingCalled); - Assert.Equal(2, testLogger.ApplicationErrorsLogged); } } + + Assert.False(onStartingCalled); + Assert.Equal(2, testLogger.ApplicationErrorsLogged); } [Theory] @@ -739,11 +710,11 @@ await connection.ReceiveForcedEnd( "Content-Length: 11", "", "Hello World"); - - Assert.True(onStartingCalled); - Assert.Equal(1, testLogger.ApplicationErrorsLogged); } } + + Assert.True(onStartingCalled); + Assert.Equal(1, testLogger.ApplicationErrorsLogged); } [Theory] @@ -781,11 +752,11 @@ await connection.ReceiveForcedEnd( "Content-Length: 11", "", "Hello"); - - Assert.True(onStartingCalled); - Assert.Equal(1, testLogger.ApplicationErrorsLogged); } } + + Assert.True(onStartingCalled); + Assert.Equal(1, testLogger.ApplicationErrorsLogged); } [Theory] @@ -925,16 +896,14 @@ await connection.ReceiveEnd( "Content-Length: 0", "", ""); - - Assert.Equal(2, onStartingCallCount2); - - // The first registered OnStarting callback should not be called, - // since they are called LIFO and the other one failed. - Assert.Equal(0, onStartingCallCount1); - - Assert.Equal(2, testLogger.ApplicationErrorsLogged); } } + + // The first registered OnStarting callback should not be called, + // since they are called LIFO and the other one failed. + Assert.Equal(0, onStartingCallCount1); + Assert.Equal(2, onStartingCallCount2); + Assert.Equal(2, testLogger.ApplicationErrorsLogged); } [Theory] @@ -979,12 +948,12 @@ await connection.ReceiveForcedEnd( "", "Hello World"); } - - // All OnCompleted callbacks should be called even if they throw. - Assert.Equal(2, testLogger.ApplicationErrorsLogged); - Assert.True(onCompletedCalled1); - Assert.True(onCompletedCalled2); } + + // All OnCompleted callbacks should be called even if they throw. + Assert.Equal(2, testLogger.ApplicationErrorsLogged); + Assert.True(onCompletedCalled1); + Assert.True(onCompletedCalled2); } [Theory] diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs index c57ddff25..ff27b37cf 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameResponseHeadersTests.cs @@ -78,24 +78,29 @@ public void AddingControlOrNonAsciiCharactersToHeadersThrows(string key, string { var responseHeaders = new FrameResponseHeaders(); - Assert.Throws(() => { + Assert.Throws(() => + { ((IHeaderDictionary)responseHeaders)[key] = value; }); - Assert.Throws(() => { + Assert.Throws(() => + { ((IHeaderDictionary)responseHeaders)[key] = new StringValues(new[] { "valid", value }); }); - Assert.Throws(() => { + Assert.Throws(() => + { ((IDictionary)responseHeaders)[key] = value; }); - Assert.Throws(() => { + Assert.Throws(() => + { var kvp = new KeyValuePair(key, value); ((ICollection>)responseHeaders).Add(kvp); }); - Assert.Throws(() => { + Assert.Throws(() => + { var kvp = new KeyValuePair(key, value); ((IDictionary)responseHeaders).Add(key, value); }); @@ -142,5 +147,83 @@ public void ThrowsWhenClearingHeadersAfterReadOnlyIsSet() Assert.Throws(() => dictionary.Clear()); } + + [Fact] + public void ThrowsWhenAddingContentLengthWithNonNumericValue() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + + Assert.Throws(() => dictionary.Add("Content-Length", new[] { "bad" })); + } + + [Fact] + public void ThrowsWhenSettingContentLengthToNonNumericValue() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + + Assert.Throws(() => ((IHeaderDictionary)headers)["Content-Length"] = "bad"); + } + + [Fact] + public void ThrowsWhenAssigningHeaderContentLengthToNonNumericValue() + { + var headers = new FrameResponseHeaders(); + Assert.Throws(() => headers.HeaderContentLength = "bad"); + } + + [Fact] + public void ContentLengthValueCanBeReadAsLongAfterAddingHeader() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary.Add("Content-Length", "42"); + + Assert.Equal(42, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueCanBeReadAsLongAfterSettingHeader() + { + var headers = new FrameResponseHeaders(); + var dictionary = (IDictionary)headers; + dictionary["Content-Length"] = "42"; + + Assert.Equal(42, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueCanBeReadAsLongAfterAssigningHeader() + { + var headers = new FrameResponseHeaders(); + headers.HeaderContentLength = "42"; + + Assert.Equal(42, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueClearedWhenHeaderIsRemoved() + { + var headers = new FrameResponseHeaders(); + headers.HeaderContentLength = "42"; + var dictionary = (IDictionary)headers; + + dictionary.Remove("Content-Length"); + + Assert.Equal(null, headers.HeaderContentLengthValue); + } + + [Fact] + public void ContentLengthValueClearedWhenHeadersCleared() + { + var headers = new FrameResponseHeaders(); + headers.HeaderContentLength = "42"; + var dictionary = (IDictionary)headers; + + dictionary.Clear(); + + Assert.Equal(null, headers.HeaderContentLengthValue); + } } } diff --git a/test/shared/TestApplicationErrorLogger.cs b/test/shared/TestApplicationErrorLogger.cs index d2d3731a9..5036f1cec 100644 --- a/test/shared/TestApplicationErrorLogger.cs +++ b/test/shared/TestApplicationErrorLogger.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Linq; using Microsoft.AspNetCore.Server.Kestrel.Internal; using Microsoft.Extensions.Logging; @@ -12,11 +14,13 @@ public class TestApplicationErrorLogger : ILogger // Application errors are logged using 13 as the eventId. private const int ApplicationErrorEventId = 13; - public int TotalErrorsLogged { get; set; } + public List Messages { get; } = new List(); - public int CriticalErrorsLogged { get; set; } + public int TotalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Error); - public int ApplicationErrorsLogged { get; set; } + public int CriticalErrorsLogged => Messages.Count(message => message.LogLevel == LogLevel.Critical); + + public int ApplicationErrorsLogged => Messages.Count(message => message.EventId.Id == ApplicationErrorEventId); public IDisposable BeginScope(TState state) { @@ -34,20 +38,14 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}"); #endif - if (eventId.Id == ApplicationErrorEventId) - { - ApplicationErrorsLogged++; - } - - if (logLevel == LogLevel.Error) - { - TotalErrorsLogged++; - } + Messages.Add(new LogMessage { LogLevel = logLevel, EventId = eventId, Exception = exception }); + } - if (logLevel == LogLevel.Critical) - { - CriticalErrorsLogged++; - } + public class LogMessage + { + public LogLevel LogLevel { get; set; } + public EventId EventId { get; set; } + public Exception Exception { get; set; } } } } diff --git a/test/shared/TestKestrelTrace.cs b/test/shared/TestKestrelTrace.cs index 63dbfc0f7..814005d4d 100644 --- a/test/shared/TestKestrelTrace.cs +++ b/test/shared/TestKestrelTrace.cs @@ -13,10 +13,6 @@ public TestKestrelTrace(ILogger testLogger) : base(testLogger) { } - public int HeadResponseWrites { get; set; } - - public int HeadResponseWriteByteCount { get; set; } - public override void ConnectionRead(string connectionId, int count) { //_logger.LogDebug(1, @"Connection id ""{ConnectionId}"" recv {count} bytes.", connectionId, count); @@ -31,11 +27,5 @@ public override void ConnectionWriteCallback(string connectionId, int status) { //_logger.LogDebug(1, @"Connection id ""{ConnectionId}"" send finished with status {status}.", connectionId, status); } - - public override void ConnectionHeadResponseBodyWrite(string connectionId, int count) - { - HeadResponseWrites++; - HeadResponseWriteByteCount = count; - } } } \ No newline at end of file diff --git a/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs b/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs index bb4f4485a..7d8b57256 100644 --- a/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs +++ b/tools/Microsoft.AspNetCore.Server.Kestrel.GeneratedCode/KnownHeaders.cs @@ -14,6 +14,11 @@ static string Each(IEnumerable values, Func formatter) return values.Any() ? values.Select(formatter).Aggregate((a, b) => a + b) : ""; } + static string If(bool condition, Func formatter) + { + return condition ? formatter() : ""; + } + class KnownHeader { public string Name { get; set; } @@ -228,7 +233,8 @@ public StringValues Header{header.Identifier} return StringValues.Empty; }} set - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = value; {(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -304,7 +310,8 @@ protected override void SetValueFast(string key, StringValues value) case {byLength.Key}: {{{Each(byLength, header => $@" if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -328,7 +335,9 @@ protected override void AddValueFast(string key, StringValues value) if ({header.TestBit()}) {{ ThrowDuplicateKeyException(); - }} + }}{ + If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -349,7 +358,8 @@ protected override bool RemoveFast(string key) if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase)) {{ if ({header.TestBit()}) - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = null;")} {header.ClearBit()}; _headers._{header.Identifier} = StringValues.Empty;{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")} @@ -369,6 +379,7 @@ protected override void ClearFast() {{ _bits = 0; _headers = default(HeaderReferences); + {(loop.ClassName == "FrameResponseHeaders" ? "_contentLength = null;" : "")} MaybeUnknown?.Clear(); }} @@ -435,7 +446,8 @@ public unsafe void Append(byte[] keyBytes, int keyOffset, int keyLength, string _headers._{header.Identifier} = AppendValue(_headers._{header.Identifier}, value); }} else - {{ + {{{If(loop.ClassName == "FrameResponseHeaders" && header.Identifier == "ContentLength", () => @" + _contentLength = ParseContentLength(value);")} {header.SetBit()}; _headers._{header.Identifier} = new StringValues(value);{(header.EnhancedSetter == false ? "" : $@" _headers._raw{header.Identifier} = null;")}