diff --git a/src/Examples/GettingStarted/Startup.cs b/src/Examples/GettingStarted/Startup.cs index 6c46707d1d..b92ffc5fe3 100644 --- a/src/Examples/GettingStarted/Startup.cs +++ b/src/Examples/GettingStarted/Startup.cs @@ -22,7 +22,9 @@ public void ConfigureServices(IServiceCollection services) public void Configure(IApplicationBuilder app, SampleDbContext context) { - context.Database.EnsureDeleted(); // indices need to be reset + // indices need to be reset + context.Database.EnsureDeleted(); + context.Database.EnsureCreated(); app.UseJsonApi(); } } diff --git a/src/JsonApiDotNetCore/Middleware/CurrentRequestMiddleware.cs b/src/JsonApiDotNetCore/Middleware/CurrentRequestMiddleware.cs index 9b6f885b15..fd0895da36 100644 --- a/src/JsonApiDotNetCore/Middleware/CurrentRequestMiddleware.cs +++ b/src/JsonApiDotNetCore/Middleware/CurrentRequestMiddleware.cs @@ -22,12 +22,6 @@ namespace JsonApiDotNetCore.Middleware public sealed class CurrentRequestMiddleware { private readonly RequestDelegate _next; - private HttpContext _httpContext; - private ICurrentRequest _currentRequest; - private IResourceGraph _resourceGraph; - private IJsonApiOptions _options; - private RouteValueDictionary _routeValues; - private IControllerResourceMapping _controllerResourceMapping; public CurrentRequestMiddleware(RequestDelegate next) { @@ -35,73 +29,73 @@ public CurrentRequestMiddleware(RequestDelegate next) } public async Task Invoke(HttpContext httpContext, - IControllerResourceMapping controllerResourceMapping, - IJsonApiOptions options, - ICurrentRequest currentRequest, - IResourceGraph resourceGraph) + IControllerResourceMapping controllerResourceMapping, + IJsonApiOptions options, + ICurrentRequest currentRequest, + IResourceGraph resourceGraph) { - _httpContext = httpContext; - _currentRequest = currentRequest; - _controllerResourceMapping = controllerResourceMapping; - _resourceGraph = resourceGraph; - _options = options; - _routeValues = httpContext.GetRouteData().Values; - var requestResource = GetCurrentEntity(); + var routeValues = httpContext.GetRouteData().Values; + var requestContext = new RequestContext(httpContext, currentRequest, resourceGraph, options, routeValues, + controllerResourceMapping); + + var requestResource = GetCurrentEntity(requestContext); if (requestResource != null) { - _currentRequest.SetRequestResource(requestResource); - _currentRequest.IsRelationshipPath = PathIsRelationship(); - _currentRequest.BasePath = GetBasePath(requestResource.ResourceName); - _currentRequest.BaseId = GetBaseId(); - _currentRequest.RelationshipId = GetRelationshipId(); + requestContext.CurrentRequest.SetRequestResource(requestResource); + requestContext.CurrentRequest.IsRelationshipPath = PathIsRelationship(requestContext.RouteValues); + requestContext.CurrentRequest.BasePath = GetBasePath(requestContext, requestResource.ResourceName); + requestContext.CurrentRequest.BaseId = GetBaseId(requestContext.RouteValues); + requestContext.CurrentRequest.RelationshipId = GetRelationshipId(requestContext); } - if (await IsValidAsync()) + if (await IsValidAsync(requestContext)) { - await _next(httpContext); + await _next(requestContext.HttpContext); } } - private string GetBaseId() + private static string GetBaseId(RouteValueDictionary routeValues) { - if (_routeValues.TryGetValue("id", out object stringId)) + if (routeValues.TryGetValue("id", out object stringId)) { return (string)stringId; } return null; } - private string GetRelationshipId() + + private static string GetRelationshipId(RequestContext requestContext) { - if (!_currentRequest.IsRelationshipPath) + if (!requestContext.CurrentRequest.IsRelationshipPath) { return null; } - var components = SplitCurrentPath(); + var components = SplitCurrentPath(requestContext); var toReturn = components.ElementAtOrDefault(4); return toReturn; } - private string[] SplitCurrentPath() + + private static string[] SplitCurrentPath(RequestContext requestContext) { - var path = _httpContext.Request.Path.Value; - var ns = $"/{_options.Namespace}"; + var path = requestContext.HttpContext.Request.Path.Value; + var ns = $"/{requestContext.Options.Namespace}"; var nonNameSpaced = path.Replace(ns, ""); nonNameSpaced = nonNameSpaced.Trim('/'); var individualComponents = nonNameSpaced.Split('/'); return individualComponents; } - private string GetBasePath(string resourceName = null) + private static string GetBasePath(RequestContext requestContext, string resourceName = null) { - var r = _httpContext.Request; - if (_options.RelativeLinks) + var r = requestContext.HttpContext.Request; + if (requestContext.Options.RelativeLinks) { - return _options.Namespace; + return requestContext.Options.Namespace; } - var customRoute = GetCustomRoute(r.Path.Value, resourceName); - var toReturn = $"{r.Scheme}://{r.Host}/{_options.Namespace}"; + var customRoute = GetCustomRoute(requestContext.Options, r.Path.Value, resourceName); + var toReturn = $"{r.Scheme}://{r.Host}/{requestContext.Options.Namespace}"; if (customRoute != null) { toReturn += $"/{customRoute}"; @@ -109,13 +103,13 @@ private string GetBasePath(string resourceName = null) return toReturn; } - private object GetCustomRoute(string path, string resourceName) + private static object GetCustomRoute(IJsonApiOptions options, string path, string resourceName) { var trimmedComponents = path.Trim('/').Split('/').ToList(); var resourceNameIndex = trimmedComponents.FindIndex(c => c == resourceName); var newComponents = trimmedComponents.Take(resourceNameIndex).ToArray(); var customRoute = string.Join('/', newComponents); - if (customRoute == _options.Namespace) + if (customRoute == options.Namespace) { return null; } @@ -125,23 +119,23 @@ private object GetCustomRoute(string path, string resourceName) } } - private bool PathIsRelationship() + private static bool PathIsRelationship(RouteValueDictionary routeValues) { - var actionName = (string)_routeValues["action"]; + var actionName = (string)routeValues["action"]; return actionName.ToLowerInvariant().Contains("relationships"); } - private async Task IsValidAsync() + private static async Task IsValidAsync(RequestContext requestContext) { - return await IsValidContentTypeHeaderAsync(_httpContext) && await IsValidAcceptHeaderAsync(_httpContext); + return await IsValidContentTypeHeaderAsync(requestContext) && await IsValidAcceptHeaderAsync(requestContext); } - private async Task IsValidContentTypeHeaderAsync(HttpContext context) + private static async Task IsValidContentTypeHeaderAsync(RequestContext requestContext) { - var contentType = context.Request.ContentType; + var contentType = requestContext.HttpContext.Request.ContentType; if (contentType != null && ContainsMediaTypeParameters(contentType)) { - await FlushResponseAsync(context, new Error(HttpStatusCode.UnsupportedMediaType) + await FlushResponseAsync(requestContext, new Error(HttpStatusCode.UnsupportedMediaType) { Title = "The specified Content-Type header value is not supported.", Detail = $"Please specify '{HeaderConstants.ContentType}' for the Content-Type header value." @@ -152,9 +146,9 @@ private async Task IsValidContentTypeHeaderAsync(HttpContext context) return true; } - private async Task IsValidAcceptHeaderAsync(HttpContext context) + private static async Task IsValidAcceptHeaderAsync(RequestContext requestContext) { - if (context.Request.Headers.TryGetValue(HeaderConstants.AcceptHeader, out StringValues acceptHeaders) == false) + if (requestContext.HttpContext.Request.Headers.TryGetValue(HeaderConstants.AcceptHeader, out StringValues acceptHeaders) == false) return true; foreach (var acceptHeader in acceptHeaders) @@ -164,7 +158,7 @@ private async Task IsValidAcceptHeaderAsync(HttpContext context) continue; } - await FlushResponseAsync(context, new Error(HttpStatusCode.NotAcceptable) + await FlushResponseAsync(requestContext, new Error(HttpStatusCode.NotAcceptable) { Title = "The specified Accept header value is not supported.", Detail = $"Please specify '{HeaderConstants.ContentType}' for the Accept header value." @@ -195,11 +189,11 @@ private static bool ContainsMediaTypeParameters(string mediaType) ); } - private async Task FlushResponseAsync(HttpContext context, Error error) + private static async Task FlushResponseAsync(RequestContext requestContext, Error error) { - context.Response.StatusCode = (int) error.StatusCode; + requestContext.HttpContext.Response.StatusCode = (int) error.StatusCode; - JsonSerializer serializer = JsonSerializer.CreateDefault(_options.SerializerSettings); + JsonSerializer serializer = JsonSerializer.CreateDefault(requestContext.Options.SerializerSettings); serializer.ApplyErrorSettings(); // https://github.com/JamesNK/Newtonsoft.Json/issues/1193 @@ -212,34 +206,59 @@ private async Task FlushResponseAsync(HttpContext context, Error error) } stream.Seek(0, SeekOrigin.Begin); - await stream.CopyToAsync(context.Response.Body); + await stream.CopyToAsync(requestContext.HttpContext.Response.Body); } - context.Response.Body.Flush(); + requestContext.HttpContext.Response.Body.Flush(); } /// /// Gets the current entity that we need for serialization and deserialization. /// /// - private ResourceContext GetCurrentEntity() + private static ResourceContext GetCurrentEntity(RequestContext requestContext) { - var controllerName = (string)_routeValues["controller"]; + var controllerName = (string)requestContext.RouteValues["controller"]; if (controllerName == null) { return null; } - var resourceType = _controllerResourceMapping.GetAssociatedResource(controllerName); - var requestResource = _resourceGraph.GetResourceContext(resourceType); + var resourceType = requestContext.ControllerResourceMapping.GetAssociatedResource(controllerName); + var requestResource = requestContext.ResourceGraph.GetResourceContext(resourceType); if (requestResource == null) { return null; } - if (_routeValues.TryGetValue("relationshipName", out object relationshipName)) + if (requestContext.RouteValues.TryGetValue("relationshipName", out object relationshipName)) { - _currentRequest.RequestRelationship = requestResource.Relationships.SingleOrDefault(r => r.PublicRelationshipName == (string)relationshipName); + requestContext.CurrentRequest.RequestRelationship = requestResource.Relationships.SingleOrDefault(r => r.PublicRelationshipName == (string)relationshipName); } return requestResource; } + + private sealed class RequestContext + { + public HttpContext HttpContext { get; } + public ICurrentRequest CurrentRequest { get; } + public IResourceGraph ResourceGraph { get; } + public IJsonApiOptions Options { get; } + public RouteValueDictionary RouteValues { get; } + public IControllerResourceMapping ControllerResourceMapping { get; } + + public RequestContext(HttpContext httpContext, + ICurrentRequest currentRequest, + IResourceGraph resourceGraph, + IJsonApiOptions options, + RouteValueDictionary routeValues, + IControllerResourceMapping controllerResourceMapping) + { + HttpContext = httpContext; + CurrentRequest = currentRequest; + ResourceGraph = resourceGraph; + Options = options; + RouteValues = routeValues; + ControllerResourceMapping = controllerResourceMapping; + } + } } }