Skip to content
This repository was archived by the owner on Dec 18, 2018. It is now read-only.

Copy HttpContext properties for long polling transport #1684

Merged
merged 8 commits into from
Mar 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public static class DefaultConnectionContextExtensions
{
public static HttpContext GetHttpContext(this ConnectionContext connection)
{
return connection.Features.Get<IHttpContextFeature>().HttpContext;
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
}

public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext)
Expand Down
97 changes: 91 additions & 6 deletions src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
Expand All @@ -10,6 +11,7 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.Protocols;
using Microsoft.AspNetCore.Protocols.Features;
using Microsoft.AspNetCore.Sockets.Internal;
Expand Down Expand Up @@ -276,8 +278,6 @@ private async Task ExecuteEndpointAsync(HttpContext context, ConnectionDelegate

connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;

connection.SetHttpContext(null);

// Dispose the cancellation token
connection.Cancellation.Dispose();

Expand Down Expand Up @@ -488,15 +488,35 @@ private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext con
return false;
}

// Setup the connection state from the http context
connection.User = context.User;

// Configure transport-specific features.
if (transportType == TransportType.LongPolling)
{
connection.Features.Set<IConnectionInherentKeepAliveFeature>(new ConnectionInherentKeepAliveFeature(options.LongPolling.PollTimeout));
}

// Setup the connection state from the http context
connection.User = context.User;
connection.SetHttpContext(context);
// For long polling, the requests come and go but the connection is still alive.
// To make the IHttpContextFeature work well, we make a copy of the relevant properties
// to a new HttpContext. This means that it's impossible to affect the context
// with subsequent requests.
var existing = connection.GetHttpContext();
if (existing == null)
{
var httpContext = CloneHttpContext(context);
connection.SetHttpContext(httpContext);
}
else
{
// Set the request trace identifier to the current http request handling the poll
existing.TraceIdentifier = context.TraceIdentifier;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Any other properties we want to update? Connection info (IP/Port/etc; In case a client changes networks, etc.)?

Of course, this is run in parallel with existing invocations so the data could "shear" (but as long as each set is atomic, it won't break anything).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extremely unlikely this would matter in real life. The reason I changed the trace identifier is because we log it to tell the user what the previous connection id is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we similarly set the existing.User value? Although we use connection.User for the Hub method auth

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

existing.User = context.User;
}
}
else
{
connection.SetHttpContext(context);
}

// Set the Connection ID on the logging scope so that logs from now on will have the
// Connection ID metadata set.
Expand All @@ -505,6 +525,65 @@ private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext con
return true;
}

private static HttpContext CloneHttpContext(HttpContext context)
{
// The reason we're copying the base features instead of the HttpContext properties is
// so that we can get all of the logic built into DefaultHttpContext to extract higher level
// structure from the low level properties
var existingRequestFeature = context.Features.Get<IHttpRequestFeature>();

var requestFeature = new HttpRequestFeature();
requestFeature.Protocol = existingRequestFeature.Protocol;
requestFeature.Method = existingRequestFeature.Method;
requestFeature.Scheme = existingRequestFeature.Scheme;
requestFeature.Path = existingRequestFeature.Path;
requestFeature.PathBase = existingRequestFeature.PathBase;
requestFeature.QueryString = existingRequestFeature.QueryString;
requestFeature.RawTarget = existingRequestFeature.RawTarget;
var requestHeaders = new Dictionary<string, StringValues>(existingRequestFeature.Headers.Count);
foreach (var header in existingRequestFeature.Headers)
{
requestHeaders[header.Key] = header.Value;
}
requestFeature.Headers = new HeaderDictionary(requestHeaders);

var existingConnectionFeature = context.Features.Get<IHttpConnectionFeature>();
var connectionFeature = new HttpConnectionFeature();

if (existingConnectionFeature != null)
{
connectionFeature.ConnectionId = existingConnectionFeature.ConnectionId;
connectionFeature.LocalIpAddress = existingConnectionFeature.LocalIpAddress;
connectionFeature.LocalPort = existingConnectionFeature.LocalPort;
connectionFeature.RemoteIpAddress = existingConnectionFeature.RemoteIpAddress;
connectionFeature.RemotePort = existingConnectionFeature.RemotePort;
}

// The response is a dud, you can't do anything with it anyways
var responseFeature = new HttpResponseFeature();

var features = new FeatureCollection();
features.Set<IHttpRequestFeature>(requestFeature);
features.Set<IHttpResponseFeature>(responseFeature);
features.Set<IHttpConnectionFeature>(connectionFeature);

// REVIEW: We could strategically look at adding other features but it might be better
// if we expose a callback that would allow the user to preserve HttpContext properties.

var newHttpContext = new DefaultHttpContext(features);
newHttpContext.TraceIdentifier = context.TraceIdentifier;
newHttpContext.User = context.User;

// Making request services function property could be tricky and expensive as it would require
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Not sure I understand.
Is this something that should be investigated further later on?

Copy link
Member Author

@davidfowl davidfowl Mar 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yah, I can remove the comment but I wanted to make it visible in the review.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is fine. Maybe log an issue and link to it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do DI scopes work with signalr? Why would it be tricky/expensive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SignalR makes a scope per hub invocation which works great. The RequestServices on the other hand would exist for the lifetime of the connection. We'd then need to dispose the scope when the connection is disposed, that in itself isn't too bad. Where it gets wonky is when somebody tries to get the same instance of a service in middleware and then in a Hub. It'll be different (which is by design and most likely fine). The strange thing about it is that RequestServices would have been replaced.

// DI scope per connection. It would also mean that services resolved in middleware leading up to here
// wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider
newHttpContext.RequestServices = EmptyServiceProvider.Instance;

// REVIEW: This extends the lifetime of anything that got put into HttpContext.Items
newHttpContext.Items = new Dictionary<object, object>(context.Items);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This worries me a bit... people use this to cache per-request stuff and we don't want to make those items live longer than the actual request.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This worries me a bit... people use this to cache per-request stuff and we don't want to make those items live longer than the actual request.

Yea, thats why I commented.

return newHttpContext;
}

private async Task<DefaultConnectionContext> GetConnectionAsync(HttpContext context, HttpSocketOptions options)
{
var connectionId = GetConnectionId(context);
Expand Down Expand Up @@ -568,5 +647,11 @@ private async Task<DefaultConnectionContext> GetOrCreateConnectionAsync(HttpCont

return connection;
}

private class EmptyServiceProvider : IServiceProvider
{
public static EmptyServiceProvider Instance { get; } = new EmptyServiceProvider();
public object GetService(Type serviceType) => null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should throw an error explaining why RequestServices doesn't work as expected.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will at least give a proper exception message when somebody is asking for a required service. That's why I didn't leave it as null. Push come to shove, we can implement something where we make a scoped service provider. I just wanted to see what people think of this first.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the error message? I was thinking that calling GetService would throw an error explaining why it isn't supported rather than return null and lead to a hard to debug NRE.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usage pattern is that the caller does GetService in order to get a null and we have a GetRequiredService that throws if GetService returns null. If anyone is calling GetService today and not expecting null then they already have an NRE

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(MicrosoftAspNetCoreAuthorizationPolicyPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(MicrosoftAspNetCoreHostingAbstractionsPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Http" Version="$(MicrosoftAspNetCoreHttpPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(MicrosoftAspNetCoreRoutingPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="$(MicrosoftAspNetCoreWebSocketsPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" PrivateAssets="All" Version="$(MicrosoftExtensionsSecurityHelperSourcesPackageVersion)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ public async Task ClientCanUseJwtBearerTokenForAuthentication(TransportType tran
}
}

[Theory(Skip = "HttpContext + Long Polling fails. Issue logged - https://github.com/aspnet/SignalR/issues/1644")]
[Theory]
[MemberData(nameof(TransportTypes))]
public async Task ClientCanSendHeaders(TransportType transportType)
{
Expand Down
10 changes: 9 additions & 1 deletion test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,15 @@ public async Task HubMethodDoesNotSendResultWhenInvocationIsNonBlocking()
client.Dispose();

// Ensure the client channel is empty
Assert.Null(client.TryRead());
var message = client.TryRead();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to "fix" this test since it kept failing with the close message thing.

/cc @JamesNK

switch (message)
{
case CloseMessage close:
break;
default:
Assert.Null(message);
break;
}

await endPointTask.OrTimeout();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Net.WebSockets;
using System.Security.Claims;
using System.Text;
Expand Down Expand Up @@ -336,6 +338,96 @@ public async Task PostSendsToConnection(TransportType transportType)
}
}

[Fact]
public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure @moozzyk would love how long this test is 😄

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:D

{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var manager = CreateConnectionManager(loggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory);
var connection = manager.CreateConnection();

using (var requestBody = new MemoryStream())
using (var responseBody = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Request.Body = requestBody;
context.Response.Body = responseBody;

var services = new ServiceCollection();
services.AddSingleton<HttpContextEndPoint>();
services.AddOptions();

// Setup state on the HttpContext
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
values["another"] = "value";
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Request.Headers["header1"] = "h1";
context.Request.Headers["header2"] = "h2";
context.Request.Headers["header3"] = "h3";
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") }));
context.TraceIdentifier = "requestid";
context.Connection.Id = "connectionid";
context.Connection.LocalIpAddress = IPAddress.Loopback;
context.Connection.LocalPort = 4563;
context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
context.Connection.RemotePort = 43456;

var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseEndPoint<HttpContextEndPoint>();
var app = builder.Build();

// Start a poll
var task = dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);

// Send to the application
var buffer = Encoding.UTF8.GetBytes("Hello World");
await connection.Application.Output.WriteAsync(buffer);

// The poll request should end
await task;

// Make sure the actual response isn't affected
Assert.Equal("application/octet-stream", context.Response.ContentType);

// Now do a new send again without the poll (that request should have ended)
await connection.Application.Output.WriteAsync(buffer);

connection.Application.Output.Complete();

// Wait for the endpoint to end
await connection.ApplicationTask;

var connectionHttpContext = connection.GetHttpContext();
Assert.NotNull(connectionHttpContext);

Assert.Equal(2, connectionHttpContext.Request.Query.Count);
Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]);
Assert.Equal("value", connectionHttpContext.Request.Query["another"]);

Assert.Equal(3, connectionHttpContext.Request.Headers.Count);
Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]);
Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]);
Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]);
Assert.Equal("requestid", connectionHttpContext.TraceIdentifier);
Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value);
Assert.Equal("connectionid", connectionHttpContext.Connection.Id);
Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress);
Assert.Equal(4563, connectionHttpContext.Connection.LocalPort);
Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress);
Assert.Equal(43456, connectionHttpContext.Connection.RemotePort);
Assert.NotNull(connectionHttpContext.RequestServices);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check that doing stuff with the response has no effect?

Copy link
Member Author

@davidfowl davidfowl Mar 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll assert something.

Assert.Equal(Stream.Null, connectionHttpContext.Response.Body);
Assert.NotNull(connectionHttpContext.Response.Headers);
Assert.Equal("application/xml", connectionHttpContext.Response.ContentType);
}
}
}

[Theory]
[InlineData(TransportType.ServerSentEvents)]
[InlineData(TransportType.LongPolling)]
Expand Down Expand Up @@ -712,7 +804,7 @@ public async Task ConnectionStateSetToInactiveAfterPoll()
await task;

Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status);
Assert.Null(connection.GetHttpContext());
Assert.NotNull(connection.GetHttpContext());

Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
Expand Down Expand Up @@ -1392,6 +1484,39 @@ public override Task OnConnectedAsync(ConnectionContext connection)
}
}

public class HttpContextEndPoint : EndPoint
{
public override async Task OnConnectedAsync(ConnectionContext connection)
{
while (true)
{
var result = await connection.Transport.Input.ReadAsync();

try
{
if (result.IsCompleted)
{
break;
}

// Make sure we have an http context
var context = connection.GetHttpContext();
Assert.NotNull(context);

// Setting the response headers should have no effect
context.Response.ContentType = "application/xml";

// Echo the results
await connection.Transport.Output.WriteAsync(result.Buffer.ToArray());
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}
}

public class TestEndPoint : EndPoint
{
public override async Task OnConnectedAsync(ConnectionContext connection)
Expand Down