Skip to content

Properly handle malformed request body and content type #320

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
using System.Threading.Tasks;
using Google.Protobuf;
using Google.Protobuf.Reflection;
using Google.Protobuf.WellKnownTypes;
using Grpc.Core;
using Grpc.Gateway.Runtime;
using Grpc.Shared.HttpApi;
Expand Down Expand Up @@ -65,11 +64,17 @@ public UnaryServerCallHandler(

public async Task HandleCallAsync(HttpContext httpContext)
{
var requestMessage = await CreateMessage(httpContext.Request);
var selectedEncoding = ResponseEncoding.SelectCharacterEncoding(httpContext.Request);

var serverCallContext = new HttpApiServerCallContext(httpContext, _unaryMethodInvoker.Method.FullName);
var (requestMessage, requestStatusCode, errorMessage) = await CreateMessage(httpContext.Request);

var selectedEncoding = ResponseEncoding.SelectCharacterEncoding(httpContext.Request);
if (requestMessage == null || requestStatusCode != StatusCode.OK)
{
await SendErrorResponse(httpContext.Response, selectedEncoding, errorMessage ?? string.Empty, requestStatusCode);
return;
}

var serverCallContext = new HttpApiServerCallContext(httpContext, _unaryMethodInvoker.Method.FullName);

TResponse responseMessage;
try
Expand Down Expand Up @@ -106,7 +111,7 @@ public async Task HandleCallAsync(HttpContext httpContext)
await SendResponse(httpContext.Response, selectedEncoding, responseMessage);
}

private async Task<IMessage> CreateMessage(HttpRequest request)
private async Task<(IMessage? requestMessage, StatusCode statusCode, string? errorMessage)> CreateMessage(HttpRequest request)
{
IMessage? requestMessage;

Expand All @@ -115,7 +120,7 @@ private async Task<IMessage> CreateMessage(HttpRequest request)
if (request.ContentType == null ||
!request.ContentType.StartsWith("application/json", StringComparison.OrdinalIgnoreCase))
{
throw new InvalidOperationException("Request content-type of application/json is required.");
return (null, StatusCode.InvalidArgument, "Request content-type of application/json is required.");
}

if (!request.Body.CanSeek)
Expand Down Expand Up @@ -150,7 +155,20 @@ private async Task<IMessage> CreateMessage(HttpRequest request)
}
else
{
var bodyContent = JsonParser.Default.Parse(requestReader, _bodyDescriptor);
IMessage bodyContent;

try
{
bodyContent = JsonParser.Default.Parse(requestReader, _bodyDescriptor);
}
catch (InvalidJsonException)
{
return (null, StatusCode.InvalidArgument, "Request JSON payload is not correctly formatted.");
}
catch (InvalidProtocolBufferException exception)
{
return (null, StatusCode.InvalidArgument, exception.Message);
}

if (_bodyFieldDescriptors != null)
{
Expand Down Expand Up @@ -192,7 +210,7 @@ private async Task<IMessage> CreateMessage(HttpRequest request)
}
}

return requestMessage;
return (requestMessage, StatusCode.OK, null);
}

private List<FieldDescriptor>? GetPathDescriptors(IMessage requestMessage, string path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,68 @@ public async Task HandleCallAsync_MatchingQueryStringValues_SetOnRequestMessage(
Assert.Equal("TestSubfield!", request!.Sub.Subfield);
}

[Theory]
[InlineData("{malformed_json}", "Request JSON payload is not correctly formatted.")]
[InlineData("{\"name\": 1234}", "Unsupported conversion from JSON number for field type String")]
[InlineData("{\"abcd\": 1234}", "Unknown field: abcd")]
public async Task HandleCallAsync_MalformedRequestBody_BadRequestReturned(string json, string expectedError)
{
// Arrange
UnaryServerMethod<HttpApiGreeterService, HelloRequest, HelloReply> invoker = (s, r, c) =>
{
return Task.FromResult(new HelloReply());
};

var unaryServerCallHandler = CreateCallHandler(
invoker,
bodyDescriptor: HelloRequest.Descriptor);
var httpContext = CreateHttpContext();
httpContext.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(json));
httpContext.Request.ContentType = "application/json";
// Act
await unaryServerCallHandler.HandleCallAsync(httpContext);

// Assert
Assert.Equal(400, httpContext.Response.StatusCode);

httpContext.Response.Body.Seek(0, SeekOrigin.Begin);
using var responseJson = JsonDocument.Parse(httpContext.Response.Body);
Assert.Equal(expectedError, responseJson.RootElement.GetProperty("message").GetString());
Assert.Equal(expectedError, responseJson.RootElement.GetProperty("error").GetString());
Assert.Equal((int)StatusCode.InvalidArgument, responseJson.RootElement.GetProperty("code").GetInt32());
}

[Theory]
[InlineData(null)]
[InlineData("text/html")]
public async Task HandleCallAsync_BadContentType_BadRequestReturned(string contentType)
{
// Arrange
UnaryServerMethod<HttpApiGreeterService, HelloRequest, HelloReply> invoker = (s, r, c) =>
{
return Task.FromResult(new HelloReply());
};

var unaryServerCallHandler = CreateCallHandler(
invoker,
bodyDescriptor: HelloRequest.Descriptor);
var httpContext = CreateHttpContext();
httpContext.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes("{}"));
httpContext.Request.ContentType = contentType;
// Act
await unaryServerCallHandler.HandleCallAsync(httpContext);

// Assert
Assert.Equal(400, httpContext.Response.StatusCode);

var expectedError = "Request content-type of application/json is required.";
httpContext.Response.Body.Seek(0, SeekOrigin.Begin);
using var responseJson = JsonDocument.Parse(httpContext.Response.Body);
Assert.Equal(expectedError, responseJson.RootElement.GetProperty("message").GetString());
Assert.Equal(expectedError, responseJson.RootElement.GetProperty("error").GetString());
Assert.Equal((int)StatusCode.InvalidArgument, responseJson.RootElement.GetProperty("code").GetInt32());
}

[Fact]
public async Task HandleCallAsync_RpcExceptionReturned_StatusReturned()
{
Expand Down