Skip to content
This repository was archived by the owner on Dec 18, 2018. It is now read-only.

Handle response content length mismatches (#175) #1155

Merged
merged 1 commit into from
Oct 11, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
Expand Down Expand Up @@ -75,7 +76,7 @@ public abstract partial class Frame : IFrameControl
protected readonly long _keepAliveMilliseconds;
private readonly long _requestHeadersTimeoutMilliseconds;

private int _responseBytesWritten;
protected long _responseBytesWritten;

public Frame(ConnectionContext context)
{
Expand Down Expand Up @@ -516,8 +517,8 @@ public async Task FlushAsync(CancellationToken cancellationToken)

public void Write(ArraySegment<byte> data)
{
VerifyAndUpdateWrite(data.Count);
ProduceStartAndFireOnStarting().GetAwaiter().GetResult();
_responseBytesWritten += data.Count;

if (_canHaveBody)
{
Expand Down Expand Up @@ -547,7 +548,7 @@ public Task WriteAsync(ArraySegment<byte> data, CancellationToken cancellationTo
return WriteAsyncAwaited(data, cancellationToken);
}

_responseBytesWritten += data.Count;
VerifyAndUpdateWrite(data.Count);

if (_canHaveBody)
{
Expand All @@ -573,8 +574,9 @@ public Task WriteAsync(ArraySegment<byte> data, CancellationToken cancellationTo

public async Task WriteAsyncAwaited(ArraySegment<byte> data, CancellationToken cancellationToken)
{
VerifyAndUpdateWrite(data.Count);

await ProduceStartAndFireOnStarting();
_responseBytesWritten += data.Count;

if (_canHaveBody)
{
Expand All @@ -598,6 +600,23 @@ public async Task WriteAsyncAwaited(ArraySegment<byte> data, CancellationToken c
}
}

private void VerifyAndUpdateWrite(int count)
{
var responseHeaders = FrameResponseHeaders;

if (responseHeaders != null &&
!responseHeaders.HasTransferEncoding &&
responseHeaders.HasContentLength &&
_responseBytesWritten + count > responseHeaders.HeaderContentLengthValue.Value)
{
_keepAlive = false;
throw new InvalidOperationException(
$"Response Content-Length mismatch: too many bytes written ({_responseBytesWritten + count} of {responseHeaders.HeaderContentLengthValue.Value}).");
}

_responseBytesWritten += count;
}

private void WriteChunked(ArraySegment<byte> data)
{
SocketOutput.Write(data, chunk: true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3697,6 +3697,7 @@ protected override void ClearFast()
{
_bits = 0;
_headers = default(HeaderReferences);

MaybeUnknown?.Clear();
}

Expand Down Expand Up @@ -5670,6 +5671,7 @@ public StringValues HeaderContentLength
}
set
{
_contentLength = ParseContentLength(value);
_bits |= 2048L;
_headers._ContentLength = value;
_headers._rawContentLength = null;
Expand Down Expand Up @@ -7384,6 +7386,7 @@ protected override void SetValueFast(string key, StringValues value)
{
if ("Content-Length".Equals(key, StringComparison.OrdinalIgnoreCase))
{
_contentLength = ParseContentLength(value);
_bits |= 2048L;
_headers._ContentLength = value;
_headers._rawContentLength = null;
Expand Down Expand Up @@ -7809,6 +7812,7 @@ protected override void AddValueFast(string key, StringValues value)
{
ThrowDuplicateKeyException();
}
_contentLength = ParseContentLength(value);
_bits |= 2048L;
_headers._ContentLength = value;
_headers._rawContentLength = null;
Expand Down Expand Up @@ -8350,6 +8354,7 @@ protected override bool RemoveFast(string key)
{
if (((_bits & 2048L) != 0))
{
_contentLength = null;
_bits &= ~2048L;
_headers._ContentLength = StringValues.Empty;
_headers._rawContentLength = null;
Expand Down Expand Up @@ -8601,6 +8606,7 @@ protected override void ClearFast()
{
_bits = 0;
_headers = default(HeaderReferences);
_contentLength = null;
MaybeUnknown?.Clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -232,6 +233,18 @@ public static void ValidateHeaderCharacters(string headerCharacters)
}
}

public static long ParseContentLength(StringValues value)
{
try
{
return long.Parse(value, NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite, CultureInfo.InvariantCulture);
}
catch (FormatException ex)
{
throw new InvalidOperationException("Content-Length value must be an integral number.", ex);
}
}

private static void ThrowInvalidHeaderCharacter(char ch)
{
throw new InvalidOperationException(string.Format("Invalid non-ASCII or control character in header: 0x{0:X4}", (ushort)ch));
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ public override async Task RequestProcessingAsync()
try
{
await _application.ProcessRequestAsync(context).ConfigureAwait(false);

var responseHeaders = FrameResponseHeaders;
if (!responseHeaders.HasTransferEncoding &&
responseHeaders.HasContentLength &&
_responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value)
{
_keepAlive = false;
ReportApplicationError(new InvalidOperationException(
$"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value})."));
}
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ public partial class FrameResponseHeaders : FrameHeaders
private static readonly byte[] _CrLf = new[] { (byte)'\r', (byte)'\n' };
private static readonly byte[] _colonSpace = new[] { (byte)':', (byte)' ' };

private long? _contentLength;

public bool HasConnection => HeaderConnection.Count != 0;

public bool HasTransferEncoding => HeaderTransferEncoding.Count != 0;
Expand All @@ -23,6 +25,8 @@ public partial class FrameResponseHeaders : FrameHeaders

public bool HasDate => HeaderDate.Count != 0;

public long? HeaderContentLengthValue => _contentLength;

public Enumerator GetEnumerator()
{
return new Enumerator(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public interface IKestrelTrace : ILogger

void ConnectionDisconnectedWrite(string connectionId, int count, Exception ex);

void ConnectionHeadResponseBodyWrite(string connectionId, int count);
void ConnectionHeadResponseBodyWrite(string connectionId, long count);

void ConnectionBadRequest(string connectionId, BadHttpRequestException ex);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class KestrelTrace : IKestrelTrace
private static readonly Action<ILogger, string, Exception> _applicationError;
private static readonly Action<ILogger, string, Exception> _connectionError;
private static readonly Action<ILogger, string, int, Exception> _connectionDisconnectedWrite;
private static readonly Action<ILogger, string, int, Exception> _connectionHeadResponseBodyWrite;
private static readonly Action<ILogger, string, long, Exception> _connectionHeadResponseBodyWrite;
private static readonly Action<ILogger, Exception> _notAllConnectionsClosedGracefully;
private static readonly Action<ILogger, string, string, Exception> _connectionBadRequest;

Expand All @@ -49,7 +49,7 @@ static KestrelTrace()
_connectionDisconnectedWrite = LoggerMessage.Define<string, int>(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client.");
_notAllConnectionsClosedGracefully = LoggerMessage.Define(LogLevel.Debug, 16, "Some connections failed to close gracefully during server shutdown.");
_connectionBadRequest = LoggerMessage.Define<string, string>(LogLevel.Information, 17, @"Connection id ""{ConnectionId}"" bad request data: ""{message}""");
_connectionHeadResponseBodyWrite = LoggerMessage.Define<string, int>(LogLevel.Debug, 18, @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response.");
_connectionHeadResponseBodyWrite = LoggerMessage.Define<string, long>(LogLevel.Debug, 18, @"Connection id ""{ConnectionId}"" write of ""{count}"" body bytes to non-body HEAD response.");
}

public KestrelTrace(ILogger logger)
Expand Down Expand Up @@ -135,7 +135,7 @@ public virtual void ConnectionDisconnectedWrite(string connectionId, int count,
_connectionDisconnectedWrite(_logger, connectionId, count, ex);
}

public virtual void ConnectionHeadResponseBodyWrite(string connectionId, int count)
public virtual void ConnectionHeadResponseBodyWrite(string connectionId, long count)
{
_connectionHeadResponseBodyWrite(_logger, connectionId, count, null);
}
Expand Down
Loading