diff --git a/src/GrpcHttpApi/src/Microsoft.AspNetCore.Grpc.HttpApi/UnaryServerCallHandler.cs b/src/GrpcHttpApi/src/Microsoft.AspNetCore.Grpc.HttpApi/UnaryServerCallHandler.cs index 2542fc02f..079fe5606 100644 --- a/src/GrpcHttpApi/src/Microsoft.AspNetCore.Grpc.HttpApi/UnaryServerCallHandler.cs +++ b/src/GrpcHttpApi/src/Microsoft.AspNetCore.Grpc.HttpApi/UnaryServerCallHandler.cs @@ -12,7 +12,6 @@ using System.Threading.Tasks; using Google.Protobuf; using Google.Protobuf.Reflection; -using Google.Protobuf.WellKnownTypes; using Grpc.Core; using Grpc.Gateway.Runtime; using Grpc.Shared.HttpApi; @@ -65,11 +64,17 @@ public UnaryServerCallHandler( public async Task HandleCallAsync(HttpContext httpContext) { - var requestMessage = await CreateMessage(httpContext.Request); + var selectedEncoding = ResponseEncoding.SelectCharacterEncoding(httpContext.Request); - var serverCallContext = new HttpApiServerCallContext(httpContext, _unaryMethodInvoker.Method.FullName); + var (requestMessage, requestStatusCode, errorMessage) = await CreateMessage(httpContext.Request); - var selectedEncoding = ResponseEncoding.SelectCharacterEncoding(httpContext.Request); + if (requestMessage == null || requestStatusCode != StatusCode.OK) + { + await SendErrorResponse(httpContext.Response, selectedEncoding, errorMessage ?? string.Empty, requestStatusCode); + return; + } + + var serverCallContext = new HttpApiServerCallContext(httpContext, _unaryMethodInvoker.Method.FullName); TResponse responseMessage; try @@ -106,7 +111,7 @@ public async Task HandleCallAsync(HttpContext httpContext) await SendResponse(httpContext.Response, selectedEncoding, responseMessage); } - private async Task CreateMessage(HttpRequest request) + private async Task<(IMessage? requestMessage, StatusCode statusCode, string? errorMessage)> CreateMessage(HttpRequest request) { IMessage? requestMessage; @@ -115,7 +120,7 @@ private async Task CreateMessage(HttpRequest request) if (request.ContentType == null || !request.ContentType.StartsWith("application/json", StringComparison.OrdinalIgnoreCase)) { - throw new InvalidOperationException("Request content-type of application/json is required."); + return (null, StatusCode.InvalidArgument, "Request content-type of application/json is required."); } if (!request.Body.CanSeek) @@ -150,7 +155,20 @@ private async Task CreateMessage(HttpRequest request) } else { - var bodyContent = JsonParser.Default.Parse(requestReader, _bodyDescriptor); + IMessage bodyContent; + + try + { + bodyContent = JsonParser.Default.Parse(requestReader, _bodyDescriptor); + } + catch (InvalidJsonException) + { + return (null, StatusCode.InvalidArgument, "Request JSON payload is not correctly formatted."); + } + catch (InvalidProtocolBufferException exception) + { + return (null, StatusCode.InvalidArgument, exception.Message); + } if (_bodyFieldDescriptors != null) { @@ -192,7 +210,7 @@ private async Task CreateMessage(HttpRequest request) } } - return requestMessage; + return (requestMessage, StatusCode.OK, null); } private List? GetPathDescriptors(IMessage requestMessage, string path) diff --git a/src/GrpcHttpApi/test/Microsoft.AspNetCore.Grpc.HttpApi.Tests/UnaryServerCallHandlerTests.cs b/src/GrpcHttpApi/test/Microsoft.AspNetCore.Grpc.HttpApi.Tests/UnaryServerCallHandlerTests.cs index 5204b18fc..c99abc4f0 100644 --- a/src/GrpcHttpApi/test/Microsoft.AspNetCore.Grpc.HttpApi.Tests/UnaryServerCallHandlerTests.cs +++ b/src/GrpcHttpApi/test/Microsoft.AspNetCore.Grpc.HttpApi.Tests/UnaryServerCallHandlerTests.cs @@ -336,6 +336,68 @@ public async Task HandleCallAsync_MatchingQueryStringValues_SetOnRequestMessage( Assert.Equal("TestSubfield!", request!.Sub.Subfield); } + [Theory] + [InlineData("{malformed_json}", "Request JSON payload is not correctly formatted.")] + [InlineData("{\"name\": 1234}", "Unsupported conversion from JSON number for field type String")] + [InlineData("{\"abcd\": 1234}", "Unknown field: abcd")] + public async Task HandleCallAsync_MalformedRequestBody_BadRequestReturned(string json, string expectedError) + { + // Arrange + UnaryServerMethod invoker = (s, r, c) => + { + return Task.FromResult(new HelloReply()); + }; + + var unaryServerCallHandler = CreateCallHandler( + invoker, + bodyDescriptor: HelloRequest.Descriptor); + var httpContext = CreateHttpContext(); + httpContext.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(json)); + httpContext.Request.ContentType = "application/json"; + // Act + await unaryServerCallHandler.HandleCallAsync(httpContext); + + // Assert + Assert.Equal(400, httpContext.Response.StatusCode); + + httpContext.Response.Body.Seek(0, SeekOrigin.Begin); + using var responseJson = JsonDocument.Parse(httpContext.Response.Body); + Assert.Equal(expectedError, responseJson.RootElement.GetProperty("message").GetString()); + Assert.Equal(expectedError, responseJson.RootElement.GetProperty("error").GetString()); + Assert.Equal((int)StatusCode.InvalidArgument, responseJson.RootElement.GetProperty("code").GetInt32()); + } + + [Theory] + [InlineData(null)] + [InlineData("text/html")] + public async Task HandleCallAsync_BadContentType_BadRequestReturned(string contentType) + { + // Arrange + UnaryServerMethod invoker = (s, r, c) => + { + return Task.FromResult(new HelloReply()); + }; + + var unaryServerCallHandler = CreateCallHandler( + invoker, + bodyDescriptor: HelloRequest.Descriptor); + var httpContext = CreateHttpContext(); + httpContext.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes("{}")); + httpContext.Request.ContentType = contentType; + // Act + await unaryServerCallHandler.HandleCallAsync(httpContext); + + // Assert + Assert.Equal(400, httpContext.Response.StatusCode); + + var expectedError = "Request content-type of application/json is required."; + httpContext.Response.Body.Seek(0, SeekOrigin.Begin); + using var responseJson = JsonDocument.Parse(httpContext.Response.Body); + Assert.Equal(expectedError, responseJson.RootElement.GetProperty("message").GetString()); + Assert.Equal(expectedError, responseJson.RootElement.GetProperty("error").GetString()); + Assert.Equal((int)StatusCode.InvalidArgument, responseJson.RootElement.GetProperty("code").GetInt32()); + } + [Fact] public async Task HandleCallAsync_RpcExceptionReturned_StatusReturned() {