Skip to content
This repository was archived by the owner on Dec 18, 2018. It is now read-only.

Commit f1a3775

Browse files
authored
Copy HttpContext properties for long polling transport (#1684)
- The long polling transport simulates a persistent connection over multiple http requests. In order to expose common http request properties, we need to copy them to a fake http context on the first poll and set that as the HttpContext exposed via the IHttpContextFeature.
1 parent b5c46f3 commit f1a3775

File tree

6 files changed

+229
-10
lines changed

6 files changed

+229
-10
lines changed

src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public static class DefaultConnectionContextExtensions
1111
{
1212
public static HttpContext GetHttpContext(this ConnectionContext connection)
1313
{
14-
return connection.Features.Get<IHttpContextFeature>().HttpContext;
14+
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
1515
}
1616

1717
public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext)

src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

44
using System;
5+
using System.Collections.Generic;
56
using System.Diagnostics;
67
using System.IO;
78
using System.IO.Pipelines;
@@ -10,6 +11,7 @@
1011
using System.Threading.Tasks;
1112
using Microsoft.AspNetCore.Http;
1213
using Microsoft.AspNetCore.Http.Features;
14+
using Microsoft.AspNetCore.Http.Internal;
1315
using Microsoft.AspNetCore.Protocols;
1416
using Microsoft.AspNetCore.Protocols.Features;
1517
using Microsoft.AspNetCore.Sockets.Internal;
@@ -276,8 +278,6 @@ private async Task ExecuteEndpointAsync(HttpContext context, ConnectionDelegate
276278

277279
connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;
278280

279-
connection.SetHttpContext(null);
280-
281281
// Dispose the cancellation token
282282
connection.Cancellation.Dispose();
283283

@@ -500,15 +500,35 @@ private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext con
500500
return false;
501501
}
502502

503+
// Setup the connection state from the http context
504+
connection.User = context.User;
505+
503506
// Configure transport-specific features.
504507
if (transportType == TransportType.LongPolling)
505508
{
506509
connection.Features.Set<IConnectionInherentKeepAliveFeature>(new ConnectionInherentKeepAliveFeature(options.LongPolling.PollTimeout));
507-
}
508510

509-
// Setup the connection state from the http context
510-
connection.User = context.User;
511-
connection.SetHttpContext(context);
511+
// For long polling, the requests come and go but the connection is still alive.
512+
// To make the IHttpContextFeature work well, we make a copy of the relevant properties
513+
// to a new HttpContext. This means that it's impossible to affect the context
514+
// with subsequent requests.
515+
var existing = connection.GetHttpContext();
516+
if (existing == null)
517+
{
518+
var httpContext = CloneHttpContext(context);
519+
connection.SetHttpContext(httpContext);
520+
}
521+
else
522+
{
523+
// Set the request trace identifier to the current http request handling the poll
524+
existing.TraceIdentifier = context.TraceIdentifier;
525+
existing.User = context.User;
526+
}
527+
}
528+
else
529+
{
530+
connection.SetHttpContext(context);
531+
}
512532

513533
// Set the Connection ID on the logging scope so that logs from now on will have the
514534
// Connection ID metadata set.
@@ -517,6 +537,65 @@ private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext con
517537
return true;
518538
}
519539

540+
private static HttpContext CloneHttpContext(HttpContext context)
541+
{
542+
// The reason we're copying the base features instead of the HttpContext properties is
543+
// so that we can get all of the logic built into DefaultHttpContext to extract higher level
544+
// structure from the low level properties
545+
var existingRequestFeature = context.Features.Get<IHttpRequestFeature>();
546+
547+
var requestFeature = new HttpRequestFeature();
548+
requestFeature.Protocol = existingRequestFeature.Protocol;
549+
requestFeature.Method = existingRequestFeature.Method;
550+
requestFeature.Scheme = existingRequestFeature.Scheme;
551+
requestFeature.Path = existingRequestFeature.Path;
552+
requestFeature.PathBase = existingRequestFeature.PathBase;
553+
requestFeature.QueryString = existingRequestFeature.QueryString;
554+
requestFeature.RawTarget = existingRequestFeature.RawTarget;
555+
var requestHeaders = new Dictionary<string, StringValues>(existingRequestFeature.Headers.Count);
556+
foreach (var header in existingRequestFeature.Headers)
557+
{
558+
requestHeaders[header.Key] = header.Value;
559+
}
560+
requestFeature.Headers = new HeaderDictionary(requestHeaders);
561+
562+
var existingConnectionFeature = context.Features.Get<IHttpConnectionFeature>();
563+
var connectionFeature = new HttpConnectionFeature();
564+
565+
if (existingConnectionFeature != null)
566+
{
567+
connectionFeature.ConnectionId = existingConnectionFeature.ConnectionId;
568+
connectionFeature.LocalIpAddress = existingConnectionFeature.LocalIpAddress;
569+
connectionFeature.LocalPort = existingConnectionFeature.LocalPort;
570+
connectionFeature.RemoteIpAddress = existingConnectionFeature.RemoteIpAddress;
571+
connectionFeature.RemotePort = existingConnectionFeature.RemotePort;
572+
}
573+
574+
// The response is a dud, you can't do anything with it anyways
575+
var responseFeature = new HttpResponseFeature();
576+
577+
var features = new FeatureCollection();
578+
features.Set<IHttpRequestFeature>(requestFeature);
579+
features.Set<IHttpResponseFeature>(responseFeature);
580+
features.Set<IHttpConnectionFeature>(connectionFeature);
581+
582+
// REVIEW: We could strategically look at adding other features but it might be better
583+
// if we expose a callback that would allow the user to preserve HttpContext properties.
584+
585+
var newHttpContext = new DefaultHttpContext(features);
586+
newHttpContext.TraceIdentifier = context.TraceIdentifier;
587+
newHttpContext.User = context.User;
588+
589+
// Making request services function property could be tricky and expensive as it would require
590+
// DI scope per connection. It would also mean that services resolved in middleware leading up to here
591+
// wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider
592+
newHttpContext.RequestServices = EmptyServiceProvider.Instance;
593+
594+
// REVIEW: This extends the lifetime of anything that got put into HttpContext.Items
595+
newHttpContext.Items = new Dictionary<object, object>(context.Items);
596+
return newHttpContext;
597+
}
598+
520599
private async Task<DefaultConnectionContext> GetConnectionAsync(HttpContext context, HttpSocketOptions options)
521600
{
522601
var connectionId = GetConnectionId(context);
@@ -580,5 +659,11 @@ private async Task<DefaultConnectionContext> GetOrCreateConnectionAsync(HttpCont
580659

581660
return connection;
582661
}
662+
663+
private class EmptyServiceProvider : IServiceProvider
664+
{
665+
public static EmptyServiceProvider Instance { get; } = new EmptyServiceProvider();
666+
public object GetService(Type serviceType) => null;
667+
}
583668
}
584669
}

src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
<ItemGroup>
2020
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(MicrosoftAspNetCoreAuthorizationPolicyPackageVersion)" />
2121
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(MicrosoftAspNetCoreHostingAbstractionsPackageVersion)" />
22+
<PackageReference Include="Microsoft.AspNetCore.Http" Version="$(MicrosoftAspNetCoreHttpPackageVersion)" />
2223
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(MicrosoftAspNetCoreRoutingPackageVersion)" />
2324
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="$(MicrosoftAspNetCoreWebSocketsPackageVersion)" />
2425
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" PrivateAssets="All" Version="$(MicrosoftExtensionsSecurityHelperSourcesPackageVersion)" />

test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ public async Task ClientCanUseJwtBearerTokenForAuthentication(TransportType tran
692692
}
693693
}
694694

695-
[Theory(Skip = "HttpContext + Long Polling fails. Issue logged - https://github.com/aspnet/SignalR/issues/1644")]
695+
[Theory]
696696
[MemberData(nameof(TransportTypes))]
697697
public async Task ClientCanSendHeaders(TransportType transportType)
698698
{

test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,15 @@ public async Task HubMethodDoesNotSendResultWhenInvocationIsNonBlocking()
540540
client.Dispose();
541541

542542
// Ensure the client channel is empty
543-
Assert.Null(client.TryRead());
543+
var message = client.TryRead();
544+
switch (message)
545+
{
546+
case CloseMessage close:
547+
break;
548+
default:
549+
Assert.Null(message);
550+
break;
551+
}
544552

545553
await endPointTask.OrTimeout();
546554
}

test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.IO.Pipelines;
9+
using System.Linq;
10+
using System.Net;
911
using System.Net.WebSockets;
1012
using System.Security.Claims;
1113
using System.Text;
@@ -337,6 +339,96 @@ public async Task PostSendsToConnection(TransportType transportType)
337339
}
338340
}
339341

342+
[Fact]
343+
public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
344+
{
345+
using (StartLog(out var loggerFactory, LogLevel.Debug))
346+
{
347+
var manager = CreateConnectionManager(loggerFactory);
348+
var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory);
349+
var connection = manager.CreateConnection();
350+
351+
using (var requestBody = new MemoryStream())
352+
using (var responseBody = new MemoryStream())
353+
{
354+
var context = new DefaultHttpContext();
355+
context.Request.Body = requestBody;
356+
context.Response.Body = responseBody;
357+
358+
var services = new ServiceCollection();
359+
services.AddSingleton<HttpContextEndPoint>();
360+
services.AddOptions();
361+
362+
// Setup state on the HttpContext
363+
context.Request.Path = "/foo";
364+
context.Request.Method = "GET";
365+
var values = new Dictionary<string, StringValues>();
366+
values["id"] = connection.ConnectionId;
367+
values["another"] = "value";
368+
var qs = new QueryCollection(values);
369+
context.Request.Query = qs;
370+
context.Request.Headers["header1"] = "h1";
371+
context.Request.Headers["header2"] = "h2";
372+
context.Request.Headers["header3"] = "h3";
373+
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") }));
374+
context.TraceIdentifier = "requestid";
375+
context.Connection.Id = "connectionid";
376+
context.Connection.LocalIpAddress = IPAddress.Loopback;
377+
context.Connection.LocalPort = 4563;
378+
context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
379+
context.Connection.RemotePort = 43456;
380+
381+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
382+
builder.UseEndPoint<HttpContextEndPoint>();
383+
var app = builder.Build();
384+
385+
// Start a poll
386+
var task = dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
387+
388+
// Send to the application
389+
var buffer = Encoding.UTF8.GetBytes("Hello World");
390+
await connection.Application.Output.WriteAsync(buffer);
391+
392+
// The poll request should end
393+
await task;
394+
395+
// Make sure the actual response isn't affected
396+
Assert.Equal("application/octet-stream", context.Response.ContentType);
397+
398+
// Now do a new send again without the poll (that request should have ended)
399+
await connection.Application.Output.WriteAsync(buffer);
400+
401+
connection.Application.Output.Complete();
402+
403+
// Wait for the endpoint to end
404+
await connection.ApplicationTask;
405+
406+
var connectionHttpContext = connection.GetHttpContext();
407+
Assert.NotNull(connectionHttpContext);
408+
409+
Assert.Equal(2, connectionHttpContext.Request.Query.Count);
410+
Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]);
411+
Assert.Equal("value", connectionHttpContext.Request.Query["another"]);
412+
413+
Assert.Equal(3, connectionHttpContext.Request.Headers.Count);
414+
Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]);
415+
Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]);
416+
Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]);
417+
Assert.Equal("requestid", connectionHttpContext.TraceIdentifier);
418+
Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value);
419+
Assert.Equal("connectionid", connectionHttpContext.Connection.Id);
420+
Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress);
421+
Assert.Equal(4563, connectionHttpContext.Connection.LocalPort);
422+
Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress);
423+
Assert.Equal(43456, connectionHttpContext.Connection.RemotePort);
424+
Assert.NotNull(connectionHttpContext.RequestServices);
425+
Assert.Equal(Stream.Null, connectionHttpContext.Response.Body);
426+
Assert.NotNull(connectionHttpContext.Response.Headers);
427+
Assert.Equal("application/xml", connectionHttpContext.Response.ContentType);
428+
}
429+
}
430+
}
431+
340432
[Theory]
341433
[InlineData(TransportType.ServerSentEvents)]
342434
[InlineData(TransportType.LongPolling)]
@@ -713,7 +805,7 @@ public async Task ConnectionStateSetToInactiveAfterPoll()
713805
await task;
714806

715807
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status);
716-
Assert.Null(connection.GetHttpContext());
808+
Assert.NotNull(connection.GetHttpContext());
717809

718810
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
719811
}
@@ -1418,6 +1510,39 @@ public override Task OnConnectedAsync(ConnectionContext connection)
14181510
}
14191511
}
14201512

1513+
public class HttpContextEndPoint : EndPoint
1514+
{
1515+
public override async Task OnConnectedAsync(ConnectionContext connection)
1516+
{
1517+
while (true)
1518+
{
1519+
var result = await connection.Transport.Input.ReadAsync();
1520+
1521+
try
1522+
{
1523+
if (result.IsCompleted)
1524+
{
1525+
break;
1526+
}
1527+
1528+
// Make sure we have an http context
1529+
var context = connection.GetHttpContext();
1530+
Assert.NotNull(context);
1531+
1532+
// Setting the response headers should have no effect
1533+
context.Response.ContentType = "application/xml";
1534+
1535+
// Echo the results
1536+
await connection.Transport.Output.WriteAsync(result.Buffer.ToArray());
1537+
}
1538+
finally
1539+
{
1540+
connection.Transport.Input.AdvanceTo(result.Buffer.End);
1541+
}
1542+
}
1543+
}
1544+
}
1545+
14211546
public class TestEndPoint : EndPoint
14221547
{
14231548
public override async Task OnConnectedAsync(ConnectionContext connection)

0 commit comments

Comments
 (0)