diff --git a/src/Analyzers/Analyzers/test/TestFiles/CompilationFeatureDetectorTest/StartupWithUseSignalR.cs b/src/Analyzers/Analyzers/test/TestFiles/CompilationFeatureDetectorTest/StartupWithUseSignalR.cs index 17159a411ca0..aa65f832580a 100644 --- a/src/Analyzers/Analyzers/test/TestFiles/CompilationFeatureDetectorTest/StartupWithUseSignalR.cs +++ b/src/Analyzers/Analyzers/test/TestFiles/CompilationFeatureDetectorTest/StartupWithUseSignalR.cs @@ -9,10 +9,12 @@ public class StartupWithUseSignalR { public void Configure(IApplicationBuilder app) { +#pragma warning disable CS0618 // Type or member is obsolete app.UseSignalR(routes => { }); +#pragma warning restore CS0618 // Type or member is obsolete } } } diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/VersionStartup.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/VersionStartup.cs index 18972ebb87d2..329ebe5211a3 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/VersionStartup.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/VersionStartup.cs @@ -1,18 +1,10 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.IdentityModel.Tokens.Jwt; -using System.Security.Claims; -using Microsoft.AspNetCore.Authentication.JwtBearer; using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.IdentityModel.Tokens; -using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { @@ -33,11 +25,12 @@ public void ConfigureServices(IServiceCollection services) public void Configure(IApplicationBuilder app) { + app.UseRouting(); app.UseAuthentication(); - app.UseSignalR(routes => + app.UseEndpoints(endpoints => { - routes.MapHub("/version"); + endpoints.MapHub("/version"); }); } } diff --git a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs index 78ca329b4516..afe027ed6653 100644 --- a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs +++ b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs @@ -12,6 +12,7 @@ public static partial class ConnectionEndpointRouteBuilderExtensions } public static partial class ConnectionsAppBuilderExtensions { + [System.ObsoleteAttribute("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapConnections or MapConnectionHandler inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseConnections(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action configure) { throw null; } } } @@ -28,6 +29,7 @@ public partial class ConnectionOptionsSetup : Microsoft.Extensions.Options.IConf public ConnectionOptionsSetup() { } public void Configure(Microsoft.AspNetCore.Http.Connections.ConnectionOptions options) { } } + [System.ObsoleteAttribute("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapConnection and MapConnectionHandler inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public partial class ConnectionsRouteBuilder { internal ConnectionsRouteBuilder() { } diff --git a/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs b/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs index bc111a2ed485..7833a466008c 100644 --- a/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs +++ b/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs @@ -3,8 +3,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Internal; @@ -48,13 +46,6 @@ public static IEndpointConventionBuilder MapConnectionHandler(this IEndpointRouteBuilder endpoints, string pattern, Action configureOptions) where TConnectionHandler : ConnectionHandler { var options = new HttpConnectionDispatcherOptions(); - // REVIEW: WE should consider removing this and instead just relying on the - // AuthorizationMiddleware - var attributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true); - foreach (var attribute in attributes.OfType()) - { - options.AuthorizationData.Add(attribute); - } configureOptions?.Invoke(options); var conventionBuilder = endpoints.MapConnections(pattern, options, b => @@ -62,6 +53,7 @@ public static IEndpointConventionBuilder MapConnectionHandler(); }); + var attributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true); conventionBuilder.Add(e => { // Add all attributes on the ConnectionHandler has metadata (this will allow for things like) @@ -93,7 +85,7 @@ public static IEndpointConventionBuilder MapConnections(this IEndpointRouteBuild var connectionDelegate = connectionBuilder.Build(); // REVIEW: Consider expanding the internals of the dispatcher as endpoint routes instead of - // using if statemants we can let the matcher handle + // using if statements we can let the matcher handle var conventionBuilders = new List(); diff --git a/src/SignalR/common/Http.Connections/src/ConnectionsAppBuilderExtensions.cs b/src/SignalR/common/Http.Connections/src/ConnectionsAppBuilderExtensions.cs index a663b31d6b66..05227d6f3963 100644 --- a/src/SignalR/common/Http.Connections/src/ConnectionsAppBuilderExtensions.cs +++ b/src/SignalR/common/Http.Connections/src/ConnectionsAppBuilderExtensions.cs @@ -3,9 +3,6 @@ using System; using Microsoft.AspNetCore.Http.Connections; -using Microsoft.AspNetCore.Http.Connections.Internal; -using Microsoft.AspNetCore.Routing; -using Microsoft.Extensions.DependencyInjection; namespace Microsoft.AspNetCore.Builder { @@ -16,10 +13,15 @@ public static class ConnectionsAppBuilderExtensions { /// /// Adds support for ASP.NET Core Connection Handlers to the request execution pipeline. + /// + /// This method is obsolete and will be removed in a future version. + /// The recommended alternative is to use MapConnections or MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...). + /// /// /// The . /// A callback to configure connection routes. /// The same instance of the for chaining. + [Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapConnections or MapConnectionHandler inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public static IApplicationBuilder UseConnections(this IApplicationBuilder app, Action configure) { if (configure == null) @@ -27,14 +29,13 @@ public static IApplicationBuilder UseConnections(this IApplicationBuilder app, A throw new ArgumentNullException(nameof(configure)); } - var dispatcher = app.ApplicationServices.GetRequiredService(); - - var routes = new RouteBuilder(app); - - configure(new ConnectionsRouteBuilder(routes, dispatcher)); - app.UseWebSockets(); - app.UseRouter(routes.Build()); + app.UseRouting(); + app.UseAuthorization(); + app.UseEndpoints(endpoints => + { + configure(new ConnectionsRouteBuilder(endpoints)); + }); return app; } } diff --git a/src/SignalR/common/Http.Connections/src/ConnectionsRouteBuilder.cs b/src/SignalR/common/Http.Connections/src/ConnectionsRouteBuilder.cs index 5bda6e664840..b76d5688f065 100644 --- a/src/SignalR/common/Http.Connections/src/ConnectionsRouteBuilder.cs +++ b/src/SignalR/common/Http.Connections/src/ConnectionsRouteBuilder.cs @@ -2,26 +2,27 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Reflection; -using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Routing; namespace Microsoft.AspNetCore.Http.Connections { /// /// Maps routes to ASP.NET Core Connection Handlers. + /// + /// This class is obsolete and will be removed in a future version. + /// The recommended alternative is to use MapConnection and MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...). + /// /// + [Obsolete("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapConnection and MapConnectionHandler inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public class ConnectionsRouteBuilder { - private readonly HttpConnectionDispatcher _dispatcher; - private readonly RouteBuilder _routes; + private readonly IEndpointRouteBuilder _endpoints; - internal ConnectionsRouteBuilder(RouteBuilder routes, HttpConnectionDispatcher dispatcher) + internal ConnectionsRouteBuilder(IEndpointRouteBuilder endpoints) { - _routes = routes; - _dispatcher = dispatcher; + _endpoints = endpoints; } /// @@ -38,24 +39,16 @@ public void MapConnections(PathString path, Action configure /// The request path. /// Options used to configure the connection. /// A callback to configure the connection. - public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action configure) - { - var connectionBuilder = new ConnectionBuilder(_routes.ServiceProvider); - configure(connectionBuilder); - var socket = connectionBuilder.Build(); - _routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket)); - _routes.MapRoute(path + "/negotiate", c => _dispatcher.ExecuteNegotiateAsync(c, options)); - } + public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action configure) => + _endpoints.MapConnections(path, options, configure); /// /// Maps incoming requests with the specified path to the provided connection pipeline. /// /// The type. /// The request path. - public void MapConnectionHandler(PathString path) where TConnectionHandler : ConnectionHandler - { + public void MapConnectionHandler(PathString path) where TConnectionHandler : ConnectionHandler => MapConnectionHandler(path, configureOptions: null); - } /// /// Maps incoming requests with the specified path to the provided connection pipeline. @@ -63,20 +56,7 @@ public void MapConnectionHandler(PathString path) where TCon /// The type. /// The request path. /// A callback to configure dispatcher options. - public void MapConnectionHandler(PathString path, Action configureOptions) where TConnectionHandler : ConnectionHandler - { - var authorizeAttributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true); - var options = new HttpConnectionDispatcherOptions(); - foreach (var attribute in authorizeAttributes) - { - options.AuthorizationData.Add(attribute); - } - configureOptions?.Invoke(options); - - MapConnections(path, options, builder => - { - builder.UseConnectionHandler(); - }); - } + public void MapConnectionHandler(PathString path, Action configureOptions) where TConnectionHandler : ConnectionHandler => + _endpoints.MapConnectionHandler(path, configureOptions); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/AuthorizeHelper.cs b/src/SignalR/common/Http.Connections/src/Internal/AuthorizeHelper.cs deleted file mode 100644 index 7b48314ef012..000000000000 --- a/src/SignalR/common/Http.Connections/src/Internal/AuthorizeHelper.cs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Collections.Generic; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Authentication; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Authorization.Policy; -using Microsoft.Extensions.DependencyInjection; - -namespace Microsoft.AspNetCore.Http.Connections.Internal -{ - internal static class AuthorizeHelper - { - public static async Task AuthorizeAsync(HttpContext context, IList policies) - { - if (policies.Count == 0) - { - return true; - } - - var policyProvider = context.RequestServices.GetRequiredService(); - - var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies); - - var policyEvaluator = context.RequestServices.GetRequiredService(); - - // This will set context.User if required - var authenticateResult = await policyEvaluator.AuthenticateAsync(authorizePolicy, context); - - var authorizeResult = await policyEvaluator.AuthorizeAsync(authorizePolicy, authenticateResult, context, resource: null); - if (authorizeResult.Succeeded) - { - return true; - } - else if (authorizeResult.Challenged) - { - if (authorizePolicy.AuthenticationSchemes.Count > 0) - { - foreach (var scheme in authorizePolicy.AuthenticationSchemes) - { - await context.ChallengeAsync(scheme); - } - } - else - { - await context.ChallengeAsync(); - } - return false; - } - else if (authorizeResult.Forbidden) - { - if (authorizePolicy.AuthenticationSchemes.Count > 0) - { - foreach (var scheme in authorizePolicy.AuthenticationSchemes) - { - await context.ForbidAsync(scheme); - } - } - else - { - await context.ForbidAsync(); - } - return false; - } - return false; - } - } -} diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 20bd0ccb2417..ea0bee5bb799 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -61,11 +61,6 @@ public async Task ExecuteAsync(HttpContext context, HttpConnectionDispatcherOpti var logScope = new ConnectionLogScope(GetConnectionId(context)); using (_logger.BeginScope(logScope)) { - if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationData)) - { - return; - } - if (HttpMethods.IsPost(context.Request.Method)) { // POST /{path} @@ -95,11 +90,6 @@ public async Task ExecuteNegotiateAsync(HttpContext context, HttpConnectionDispa var logScope = new ConnectionLogScope(connectionId: string.Empty); using (_logger.BeginScope(logScope)) { - if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationData)) - { - return; - } - if (HttpMethods.IsPost(context.Request.Method)) { // POST /{path}/negotiate diff --git a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj index c7dfbeb9e120..175167df71a8 100644 --- a/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj +++ b/src/SignalR/common/Http.Connections/src/Microsoft.AspNetCore.Http.Connections.csproj @@ -14,6 +14,7 @@ + diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 2fcf129790ea..d79bcd3d9215 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1347,358 +1347,6 @@ public async Task TransferModeSet(HttpTransportType transportType, TransferForma } } - [Fact] - public async Task UnauthorizedConnectionFailsToStartEndPoint() - { - using (StartVerifiableLog()) - { - var manager = CreateConnectionManager(LoggerFactory); - var connection = manager.CreateConnection(); - connection.TransportType = HttpTransportType.LongPolling; - var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddSingleton(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); - }); - services.AddAuthenticationCore(o => - { - o.DefaultScheme = "Default"; - o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); - }); - services.AddLogging(); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - - var builder = new ConnectionBuilder(sp); - builder.UseConnectionHandler(); - var app = builder.Build(); - var options = new HttpConnectionDispatcherOptions(); - options.AuthorizationData.Add(new AuthorizeAttribute("test")); - - // would get stuck if EndPoint was running - await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); - - Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); - } - } - - [Fact] - public async Task AuthenticatedUserWithoutPermissionCausesForbidden() - { - using (StartVerifiableLog()) - { - var manager = CreateConnectionManager(LoggerFactory); - var connection = manager.CreateConnection(); - connection.TransportType = HttpTransportType.LongPolling; - var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddSingleton(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); - }); - services.AddAuthenticationCore(o => - { - o.DefaultScheme = "Default"; - o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); - }); - services.AddLogging(); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - - var builder = new ConnectionBuilder(sp); - builder.UseConnectionHandler(); - var app = builder.Build(); - var options = new HttpConnectionDispatcherOptions(); - options.AuthorizationData.Add(new AuthorizeAttribute("test")); - - context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated")); - - // would get stuck if EndPoint was running - await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); - - Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode); - } - } - - [Fact] - public async Task AuthorizedConnectionCanConnectToEndPoint() - { - using (StartVerifiableLog()) - { - var manager = CreateConnectionManager(LoggerFactory); - var connection = manager.CreateConnection(); - connection.TransportType = HttpTransportType.LongPolling; - var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = new DefaultHttpContext(); - context.Features.Set(new ResponseFeature()); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddSingleton(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => - { - policy.RequireClaim(ClaimTypes.NameIdentifier); - }); - }); - services.AddLogging(); - services.AddAuthenticationCore(o => - { - o.DefaultScheme = "Default"; - o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); - }); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - context.Response.Body = new MemoryStream(); - - var builder = new ConnectionBuilder(sp); - builder.UseConnectionHandler(); - var app = builder.Build(); - var options = new HttpConnectionDispatcherOptions(); - options.AuthorizationData.Add(new AuthorizeAttribute("test")); - - // "authorize" user - context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - - var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - await connectionHandlerTask.OrTimeout(); - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - - connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); - await connectionHandlerTask.OrTimeout(); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - Assert.Equal("Hello, World", GetContentAsString(context.Response.Body)); - } - } - - [Fact] - public async Task AllPoliciesRequiredForAuthorizedEndPoint() - { - using (StartVerifiableLog()) - { - var manager = CreateConnectionManager(LoggerFactory); - var connection = manager.CreateConnection(); - connection.TransportType = HttpTransportType.LongPolling; - var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = new DefaultHttpContext(); - context.Features.Set(new ResponseFeature()); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddSingleton(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => - { - policy.RequireClaim(ClaimTypes.NameIdentifier); - }); - o.AddPolicy("secondPolicy", policy => - { - policy.RequireClaim(ClaimTypes.StreetAddress); - }); - }); - services.AddLogging(); - services.AddAuthenticationCore(o => - { - o.DefaultScheme = "Default"; - o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); - }); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - context.Response.Body = new MemoryStream(); - - var builder = new ConnectionBuilder(sp); - builder.UseConnectionHandler(); - var app = builder.Build(); - var options = new HttpConnectionDispatcherOptions(); - options.AuthorizationData.Add(new AuthorizeAttribute("test")); - options.AuthorizationData.Add(new AuthorizeAttribute("secondPolicy")); - - // partially "authorize" user - context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - - // would get stuck if EndPoint was running - await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); - - Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); - - // reset HttpContext - context = new DefaultHttpContext(); - context.Features.Set(new ResponseFeature()); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - context.Request.Query = qs; - context.Response.Body = new MemoryStream(); - // fully "authorize" user - context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] - { - new Claim(ClaimTypes.NameIdentifier, "name"), - new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") - })); - - // First poll - var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - Assert.True(connectionHandlerTask.IsCompleted); - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - - connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); - - await connectionHandlerTask.OrTimeout(); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - Assert.Equal("Hello, World", GetContentAsString(context.Response.Body)); - } - } - - [Fact] - public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint() - { - using (StartVerifiableLog()) - { - var manager = CreateConnectionManager(LoggerFactory); - var connection = manager.CreateConnection(); - connection.TransportType = HttpTransportType.LongPolling; - var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = new DefaultHttpContext(); - context.Features.Set(new ResponseFeature()); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddSingleton(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => - { - policy.RequireClaim(ClaimTypes.NameIdentifier); - policy.AddAuthenticationSchemes("Default"); - }); - }); - services.AddLogging(); - services.AddAuthenticationCore(o => - { - o.DefaultScheme = "Default"; - o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); - }); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - context.Response.Body = new MemoryStream(); - - var builder = new ConnectionBuilder(sp); - builder.UseConnectionHandler(); - var app = builder.Build(); - var options = new HttpConnectionDispatcherOptions(); - options.AuthorizationData.Add(new AuthorizeAttribute("test")); - - // "authorize" user - context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - - // Initial poll - var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - Assert.True(connectionHandlerTask.IsCompleted); - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - - connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); - - await connectionHandlerTask.OrTimeout(); - - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - Assert.Equal("Hello, World", GetContentAsString(context.Response.Body)); - } - } - - [Fact] - public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint() - { - using (StartVerifiableLog()) - { - var manager = CreateConnectionManager(LoggerFactory); - var connection = manager.CreateConnection(); - connection.TransportType = HttpTransportType.LongPolling; - var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddSingleton(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => - { - policy.RequireClaim(ClaimTypes.NameIdentifier); - policy.AddAuthenticationSchemes("Default"); - }); - }); - services.AddLogging(); - services.AddAuthenticationCore(o => - { - o.DefaultScheme = "Default"; - o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler)); - }); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - context.Response.Body = new MemoryStream(); - - var builder = new ConnectionBuilder(sp); - builder.UseConnectionHandler(); - var app = builder.Build(); - var options = new HttpConnectionDispatcherOptions(); - options.AuthorizationData.Add(new AuthorizeAttribute("test")); - - // "authorize" user - context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - - // would block if EndPoint was executed - await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); - - Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); - } - } - [Fact] public async Task SetsInherentKeepAliveFeatureOnFirstLongPollingRequest() { @@ -2119,50 +1767,6 @@ bool ExpectedErrors(WriteContext writeContext) } } - private class RejectHandler : TestAuthenticationHandler - { - protected override bool ShouldAccept => false; - } - - private class TestAuthenticationHandler : IAuthenticationHandler - { - private HttpContext HttpContext; - private AuthenticationScheme _scheme; - - protected virtual bool ShouldAccept => true; - - public Task AuthenticateAsync() - { - if (ShouldAccept) - { - return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(HttpContext.User, _scheme.Name))); - } - else - { - return Task.FromResult(AuthenticateResult.NoResult()); - } - } - - public Task ChallengeAsync(AuthenticationProperties properties) - { - HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; - return Task.CompletedTask; - } - - public Task ForbidAsync(AuthenticationProperties properties) - { - HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden; - return Task.CompletedTask; - } - - public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) - { - HttpContext = context; - _scheme = scheme; - return Task.CompletedTask; - } - } - private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory) { var manager = CreateConnectionManager(loggerFactory); diff --git a/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs b/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs index 4541f5e3f1a6..1b5716c01c29 100644 --- a/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs +++ b/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs @@ -4,11 +4,14 @@ using System; using System.Linq; using System.Net.WebSockets; +using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Cors; +using Microsoft.AspNetCore.Cors.Infrastructure; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.Routing; @@ -34,39 +37,113 @@ public MapConnectionHandlerTests(ITestOutputHelper output) public void MapConnectionHandlerFindsAuthAttributeOnEndPoint() { var authCount = 0; - using (var builder = BuildWebHost("/auth", + using (var host = BuildWebHost("/auth", options => authCount += options.AuthorizationData.Count)) { - builder.Start(); + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/auth/negotiate", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + }, + endpoint => + { + Assert.Equal("/auth", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + }); } - Assert.Equal(1, authCount); + Assert.Equal(0, authCount); } [Fact] public void MapConnectionHandlerFindsAuthAttributeOnInheritedEndPoint() { var authCount = 0; - using (var builder = BuildWebHost("/auth", + using (var host = BuildWebHost("/auth", options => authCount += options.AuthorizationData.Count)) { - builder.Start(); + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/auth/negotiate", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + }, + endpoint => + { + Assert.Equal("/auth", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + }); } - Assert.Equal(1, authCount); + Assert.Equal(0, authCount); } [Fact] + public void MapConnectionHandlerFindsAuthAttributesOnDoubleAuthEndPoint() { var authCount = 0; - using (var builder = BuildWebHost("/auth", + using (var host = BuildWebHost("/auth", options => authCount += options.AuthorizationData.Count)) { - builder.Start(); + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/auth/negotiate", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/auth", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }); + } + + Assert.Equal(0, authCount); + } + + [Fact] + public void MapConnectionHandlerFindsAttributesFromEndPointAndOptions() + { + var authCount = 0; + using (var host = BuildWebHost("/auth", + options => + { + authCount += options.AuthorizationData.Count; + options.AuthorizationData.Add(new AuthorizeAttribute()); + })) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/auth/negotiate", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/auth", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }); } - Assert.Equal(2, authCount); + Assert.Equal(0, authCount); } [Fact] @@ -82,12 +159,50 @@ public void MapConnectionHandlerEndPointRoutingFindsAttributesOnHub() var dataSource = host.Services.GetRequiredService(); // We register 2 endpoints (/negotiate and /) - Assert.Equal(2, dataSource.Endpoints.Count); - Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); - Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata()); + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + }); } - Assert.Equal(1, authCount); + Assert.Equal(0, authCount); + } + + [Fact] + public void MapConnectionHandlerEndPointRoutingFindsAttributesFromOptions() + { + var authCount = 0; + using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapConnectionHandler("/path", options => + { + authCount += options.AuthorizationData.Count; + options.AuthorizationData.Add(new AuthorizeAttribute()); + }))) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }); + } + + Assert.Equal(0, authCount); } [Fact] @@ -106,9 +221,27 @@ void ConfigureRoutes(IEndpointRouteBuilder endpoints) var dataSource = host.Services.GetRequiredService(); // We register 2 endpoints (/negotiate and /) - Assert.Equal(2, dataSource.Endpoints.Count); - Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata()?.Policy); - Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata()?.Policy); + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Collection(endpoint.Metadata.GetOrderedMetadata(), + auth => { }, + auth => + { + Assert.Equal("Foo", auth?.Policy); + }); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Collection(endpoint.Metadata.GetOrderedMetadata(), + auth => { }, + auth => + { + Assert.Equal("Foo", auth?.Policy); + }); + }); } } @@ -126,9 +259,45 @@ void ConfigureRoutes(IEndpointRouteBuilder endpoints) var dataSource = host.Services.GetRequiredService(); // We register 2 endpoints (/negotiate and /) - Assert.Equal(2, dataSource.Endpoints.Count); - Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); - Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata()); + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.NotNull(endpoint.Metadata.GetMetadata()); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Null(endpoint.Metadata.GetMetadata()); + }); + } + } + + [Fact] + public void MapConnectionHandlerEndPointRoutingAppliesCorsMetadata() + { + void ConfigureRoutes(IEndpointRouteBuilder endpoints) + { + endpoints.MapConnectionHandler("/path"); + } + + using (var host = BuildWebHostWithEndPointRouting(ConfigureRoutes)) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.NotNull(endpoint.Metadata.GetMetadata()); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.NotNull(endpoint.Metadata.GetMetadata()); + }); } } @@ -177,6 +346,15 @@ public override async Task OnConnectedAsync(ConnectionContext connection) } } + [EnableCors] + private class CorsConnectionHandler : ConnectionHandler + { + public override Task OnConnectedAsync(ConnectionContext connection) + { + throw new NotImplementedException(); + } + } + private class InheritedAuthConnectionHandler : AuthConnectionHandler { public override Task OnConnectedAsync(ConnectionContext connection) @@ -227,10 +405,12 @@ private IWebHost BuildWebHost(string path, Action { +#pragma warning disable CS0618 // Type or member is obsolete app.UseConnections(routes => { routes.MapConnectionHandler(path, configureOptions); }); +#pragma warning restore CS0618 // Type or member is obsolete }) .ConfigureLogging(factory => { diff --git a/src/SignalR/common/Http.Connections/test/Microsoft.AspNetCore.Http.Connections.Tests.csproj b/src/SignalR/common/Http.Connections/test/Microsoft.AspNetCore.Http.Connections.Tests.csproj index 1f6b36a8bae9..5a8aca685fdb 100644 --- a/src/SignalR/common/Http.Connections/test/Microsoft.AspNetCore.Http.Connections.Tests.csproj +++ b/src/SignalR/common/Http.Connections/test/Microsoft.AspNetCore.Http.Connections.Tests.csproj @@ -15,6 +15,7 @@ + diff --git a/src/SignalR/server/Core/src/Internal/TaskCache.cs b/src/SignalR/common/Shared/TaskCache.cs similarity index 88% rename from src/SignalR/server/Core/src/Internal/TaskCache.cs rename to src/SignalR/common/Shared/TaskCache.cs index e11f53506e58..2df6d85ed979 100644 --- a/src/SignalR/server/Core/src/Internal/TaskCache.cs +++ b/src/SignalR/common/Shared/TaskCache.cs @@ -3,11 +3,11 @@ using System.Threading.Tasks; -namespace Microsoft.AspNetCore.SignalR.Internal +namespace Microsoft.AspNetCore.Internal { internal static class TaskCache { public static readonly Task True = Task.FromResult(true); public static readonly Task False = Task.FromResult(false); } -} \ No newline at end of file +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index ee39dcad9d1d..16f1b8da63e9 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -11,6 +11,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index 8d8194fbc7ae..5db9ea75b23f 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -12,6 +12,7 @@ + diff --git a/src/SignalR/server/SignalR/ref/Microsoft.AspNetCore.SignalR.netcoreapp3.0.cs b/src/SignalR/server/SignalR/ref/Microsoft.AspNetCore.SignalR.netcoreapp3.0.cs index 27663babccad..56b62a7317ec 100644 --- a/src/SignalR/server/SignalR/ref/Microsoft.AspNetCore.SignalR.netcoreapp3.0.cs +++ b/src/SignalR/server/SignalR/ref/Microsoft.AspNetCore.SignalR.netcoreapp3.0.cs @@ -10,6 +10,7 @@ public static partial class HubEndpointRouteBuilderExtensions } public static partial class SignalRAppBuilderExtensions { + [System.ObsoleteAttribute("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapHub inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseSignalR(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action configure) { throw null; } } } @@ -25,6 +26,7 @@ public sealed partial class HubEndpointConventionBuilder : Microsoft.AspNetCore. internal HubEndpointConventionBuilder() { } public void Add(System.Action convention) { } } + [System.ObsoleteAttribute("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapHub inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public partial class HubRouteBuilder { public HubRouteBuilder(Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder routes) { } diff --git a/src/SignalR/server/SignalR/src/HubEndpointRouteBuilderExtensions.cs b/src/SignalR/server/SignalR/src/HubEndpointRouteBuilderExtensions.cs index f53632d00f6d..553ea3f8ce12 100644 --- a/src/SignalR/server/SignalR/src/HubEndpointRouteBuilderExtensions.cs +++ b/src/SignalR/server/SignalR/src/HubEndpointRouteBuilderExtensions.cs @@ -2,8 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Linq; -using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.SignalR; @@ -44,14 +42,6 @@ public static HubEndpointConventionBuilder MapHub(this IEndpointRouteBuild } var options = new HttpConnectionDispatcherOptions(); - // REVIEW: WE should consider removing this and instead just relying on the - // AuthorizationMiddleware - var attributes = typeof(THub).GetCustomAttributes(inherit: true); - foreach (var attribute in attributes.OfType()) - { - options.AuthorizationData.Add(attribute); - } - configureOptions?.Invoke(options); var conventionBuilder = endpoints.MapConnections(pattern, options, b => @@ -59,9 +49,10 @@ public static HubEndpointConventionBuilder MapHub(this IEndpointRouteBuild b.UseHub(); }); + var attributes = typeof(THub).GetCustomAttributes(inherit: true); conventionBuilder.Add(e => { - // Add all attributes on the Hub has metadata (this will allow for things like) + // Add all attributes on the Hub as metadata (this will allow for things like) // auth attributes and cors attributes to work seamlessly foreach (var item in attributes) { diff --git a/src/SignalR/server/SignalR/src/HubRouteBuilder.cs b/src/SignalR/server/SignalR/src/HubRouteBuilder.cs index 06f945bbb0bc..cf9f694bb355 100644 --- a/src/SignalR/server/SignalR/src/HubRouteBuilder.cs +++ b/src/SignalR/server/SignalR/src/HubRouteBuilder.cs @@ -4,17 +4,25 @@ using System; using System.Reflection; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Routing; namespace Microsoft.AspNetCore.SignalR { /// /// Maps incoming requests to types. + /// + /// This class is obsolete and will be removed in a future version. + /// The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...). + /// /// + [Obsolete("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapHub inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public class HubRouteBuilder { private readonly ConnectionsRouteBuilder _routes; + private readonly IEndpointRouteBuilder _endpoints; /// /// Initializes a new instance of the class. @@ -25,6 +33,11 @@ public HubRouteBuilder(ConnectionsRouteBuilder routes) _routes = routes; } + internal HubRouteBuilder(IEndpointRouteBuilder endpoints) + { + _endpoints = endpoints; + } + /// /// Maps incoming requests with the specified path to the specified type. /// @@ -43,6 +56,14 @@ public void MapHub(PathString path) where THub : Hub /// A callback to configure dispatcher options. public void MapHub(PathString path, Action configureOptions) where THub : Hub { + // This will be null if someone is manually using the HubRouteBuilder(ConnectionsRouteBuilder routes) constructor + // SignalR itself will only use the IEndpointRouteBuilder overload + if (_endpoints != null) + { + _endpoints.MapHub(path, configureOptions); + return; + } + // find auth attributes var authorizeAttributes = typeof(THub).GetCustomAttributes(inherit: true); var options = new HttpConnectionDispatcherOptions(); diff --git a/src/SignalR/server/SignalR/src/SignalRAppBuilderExtensions.cs b/src/SignalR/server/SignalR/src/SignalRAppBuilderExtensions.cs index 8a85de1a7096..e01870193a52 100644 --- a/src/SignalR/server/SignalR/src/SignalRAppBuilderExtensions.cs +++ b/src/SignalR/server/SignalR/src/SignalRAppBuilderExtensions.cs @@ -14,10 +14,15 @@ public static class SignalRAppBuilderExtensions { /// /// Adds SignalR to the request execution pipeline. + /// + /// This method is obsolete and will be removed in a future version. + /// The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...). + /// /// /// The . /// A callback to configure hub routes. /// The same instance of the for chaining. + [Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapHub inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] public static IApplicationBuilder UseSignalR(this IApplicationBuilder app, Action configure) { var marker = app.ApplicationServices.GetService(); @@ -27,9 +32,12 @@ public static IApplicationBuilder UseSignalR(this IApplicationBuilder app, Actio "'IServiceCollection.AddSignalR' inside the call to 'ConfigureServices(...)' in the application startup code."); } - app.UseConnections(routes => + app.UseWebSockets(); + app.UseRouting(); + app.UseAuthorization(); + app.UseEndpoints(endpoints => { - configure(new HubRouteBuilder(routes)); + configure(new HubRouteBuilder(endpoints)); }); return app; diff --git a/src/SignalR/server/SignalR/test/AuthHub.cs b/src/SignalR/server/SignalR/test/AuthHub.cs new file mode 100644 index 000000000000..6adf678870c8 --- /dev/null +++ b/src/SignalR/server/SignalR/test/AuthHub.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Authorization; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + [Authorize] + class AuthHub : Hub + { + } +} diff --git a/src/SignalR/server/SignalR/test/EndToEndTests.cs b/src/SignalR/server/SignalR/test/EndToEndTests.cs index ae59219a945a..b3aabeee1c8c 100644 --- a/src/SignalR/server/SignalR/test/EndToEndTests.cs +++ b/src/SignalR/server/SignalR/test/EndToEndTests.cs @@ -446,6 +446,52 @@ bool ExpectedErrors(WriteContext writeContext) } } + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task AuthorizedConnectionCanConnect() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var logger = LoggerFactory.CreateLogger(); + + string token; + using (var client = new HttpClient()) + { + client.BaseAddress = new Uri(server.Url); + + var response = await client.GetAsync("generatetoken?user=bob"); + token = await response.Content.ReadAsStringAsync(); + } + + var url = server.Url + "/auth"; + var connection = new HttpConnection(new HttpConnectionOptions() + { + AccessTokenProvider = () => Task.FromResult(token), + Url = new Uri(url), + Transports = HttpTransportType.ServerSentEvents + }, LoggerFactory); + + try + { + logger.LogInformation("Starting connection to {url}", url); + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + logger.LogInformation("Connected to {url}", url); + } + finally + { + logger.LogInformation("Disposing Connection"); + await connection.DisposeAsync().OrTimeout(); + logger.LogInformation("Disposed Connection"); + } + } + } + [ConditionalFact] [WebSocketsSupportedCondition] public async Task ServerClosesConnectionWithErrorIfHubCannotBeCreated_WebSocket() @@ -532,6 +578,178 @@ private async Task ServerClosesConnectionWithErrorIfHubCannotBeCreated(HttpTrans } } + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task UnauthorizedHubConnectionDoesNotConnectWithEndpoints() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var logger = LoggerFactory.CreateLogger(); + + var url = server.Url + "/authHubEndpoints"; + var connection = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(url, HttpTransportType.LongPolling) + .Build(); + + try + { + logger.LogInformation("Starting connection to {url}", url); + await connection.StartAsync().OrTimeout(); + Assert.True(false); + } + catch (Exception ex) + { + Assert.Equal("Response status code does not indicate success: 401 (Unauthorized).", ex.Message); + } + finally + { + logger.LogInformation("Disposing Connection"); + await connection.DisposeAsync().OrTimeout(); + logger.LogInformation("Disposed Connection"); + } + } + } + + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task UnauthorizedHubConnectionDoesNotConnect() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var logger = LoggerFactory.CreateLogger(); + + var url = server.Url + "/authHub"; + var connection = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(url, HttpTransportType.LongPolling) + .Build(); + + try + { + logger.LogInformation("Starting connection to {url}", url); + await connection.StartAsync().OrTimeout(); + Assert.True(false); + } + catch (Exception ex) + { + Assert.Equal("Response status code does not indicate success: 401 (Unauthorized).", ex.Message); + } + finally + { + logger.LogInformation("Disposing Connection"); + await connection.DisposeAsync().OrTimeout(); + logger.LogInformation("Disposed Connection"); + } + } + } + + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task AuthorizedHubConnectionCanConnectWithEndpoints() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var logger = LoggerFactory.CreateLogger(); + + string token; + using (var client = new HttpClient()) + { + client.BaseAddress = new Uri(server.Url); + + var response = await client.GetAsync("generatetoken?user=bob"); + token = await response.Content.ReadAsStringAsync(); + } + + var url = server.Url + "/authHubEndpoints"; + var connection = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(url, HttpTransportType.LongPolling, o => + { + o.AccessTokenProvider = () => Task.FromResult(token); + }) + .Build(); + + try + { + logger.LogInformation("Starting connection to {url}", url); + await connection.StartAsync().OrTimeout(); + logger.LogInformation("Connected to {url}", url); + } + finally + { + logger.LogInformation("Disposing Connection"); + await connection.DisposeAsync().OrTimeout(); + logger.LogInformation("Disposed Connection"); + } + } + } + + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task AuthorizedHubConnectionCanConnect() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorWithNegotiation"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var logger = LoggerFactory.CreateLogger(); + + string token; + using (var client = new HttpClient()) + { + client.BaseAddress = new Uri(server.Url); + + var response = await client.GetAsync("generatetoken?user=bob"); + token = await response.Content.ReadAsStringAsync(); + } + + var url = server.Url + "/authHub"; + var connection = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(url, HttpTransportType.LongPolling, o => + { + o.AccessTokenProvider = () => Task.FromResult(token); + }) + .Build(); + + try + { + logger.LogInformation("Starting connection to {url}", url); + await connection.StartAsync().OrTimeout(); + logger.LogInformation("Connected to {url}", url); + } + finally + { + logger.LogInformation("Disposing Connection"); + await connection.DisposeAsync().OrTimeout(); + logger.LogInformation("Disposed Connection"); + } + } + } + // Serves a fake transport that lets us verify fallback behavior private class TestTransportFactory : ITransportFactory { diff --git a/src/SignalR/server/SignalR/test/MapSignalRTests.cs b/src/SignalR/server/SignalR/test/MapSignalRTests.cs index 40ea60382fe6..01bc4d2e3370 100644 --- a/src/SignalR/server/SignalR/test/MapSignalRTests.cs +++ b/src/SignalR/server/SignalR/test/MapSignalRTests.cs @@ -1,7 +1,4 @@ using System; -using System.Collections; -using System.Collections.Generic; -using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; @@ -42,10 +39,12 @@ public void NotAddingSignalRServiceThrows() var ex = Assert.Throws(() => { +#pragma warning disable CS0618 // Type or member is obsolete app.UseSignalR(routes => { routes.MapHub("/overloads"); }); +#pragma warning restore CS0618 // Type or member is obsolete }); Assert.Equal("Unable to find the required services. Please add all the required services by calling " + @@ -109,9 +108,23 @@ public void MapHubFindsAuthAttributeOnHub() }))) { host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + }); } - Assert.Equal(1, authCount); + Assert.Equal(0, authCount); } [Fact] @@ -124,9 +137,23 @@ public void MapHubFindsAuthAttributeOnInheritedHub() }))) { host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + }); } - Assert.Equal(1, authCount); + Assert.Equal(0, authCount); } [Fact] @@ -139,9 +166,23 @@ public void MapHubFindsMultipleAuthAttributesOnDoubleAuthHub() }))) { host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }); } - Assert.Equal(2, authCount); + Assert.Equal(0, authCount); } [Fact] @@ -157,12 +198,52 @@ public void MapHubEndPointRoutingFindsAttributesOnHub() var dataSource = host.Services.GetRequiredService(); // We register 2 endpoints (/negotiate and /) - Assert.Equal(2, dataSource.Endpoints.Count); - Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); - Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata()); + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + }); } - Assert.Equal(1, authCount); + Assert.Equal(0, authCount); + } + + [Fact] + public void MapHubEndPointRoutingFindsAttributesOnHubAndFromOptions() + { + var authCount = 0; + HttpConnectionDispatcherOptions configuredOptions = null; + using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapHub("/path", options => + { + authCount += options.AuthorizationData.Count; + options.AuthorizationData.Add(new AuthorizeAttribute()); + configuredOptions = options; + }))) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata().Count); + }); + } + + Assert.Equal(0, authCount); } [Fact] @@ -181,9 +262,27 @@ void ConfigureRoutes(IEndpointRouteBuilder endpoints) var dataSource = host.Services.GetRequiredService(); // We register 2 endpoints (/negotiate and /) - Assert.Equal(2, dataSource.Endpoints.Count); - Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata()?.Policy); - Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata()?.Policy); + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Collection(endpoint.Metadata.GetOrderedMetadata(), + auth => { }, + auth => + { + Assert.Equal("Foo", auth?.Policy); + }); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Collection(endpoint.Metadata.GetOrderedMetadata(), + auth => { }, + auth => + { + Assert.Equal("Foo", auth?.Policy); + }); + }); } } @@ -202,11 +301,52 @@ void ConfigureRoutes(IEndpointRouteBuilder endpoints) var dataSource = host.Services.GetRequiredService(); // We register 2 endpoints (/negotiate and /) - Assert.Equal(2, dataSource.Endpoints.Count); - Assert.Equal(typeof(AuthHub), dataSource.Endpoints[0].Metadata.GetMetadata()?.HubType); - Assert.Equal(typeof(AuthHub), dataSource.Endpoints[1].Metadata.GetMetadata()?.HubType); - Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); - Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata()); + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata()?.HubType); + Assert.NotNull(endpoint.Metadata.GetMetadata()); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata()?.HubType); + Assert.Null(endpoint.Metadata.GetMetadata()); + }); + } + } + + [Fact] + public void MapHubAppliesHubMetadata() + { +#pragma warning disable CS0618 // Type or member is obsolete + void ConfigureRoutes(HubRouteBuilder routes) +#pragma warning restore CS0618 // Type or member is obsolete + { + // This "Foo" policy should override the default auth attribute + routes.MapHub("/path"); + } + + using (var host = BuildWebHost(ConfigureRoutes)) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata()?.HubType); + Assert.NotNull(endpoint.Metadata.GetMetadata()); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata()?.HubType); + Assert.Null(endpoint.Metadata.GetMetadata()); + }); } } @@ -252,6 +392,7 @@ private IWebHost BuildWebHostWithEndPointRouting(Action c .Build(); } +#pragma warning disable CS0618 // Type or member is obsolete private IWebHost BuildWebHost(Action configure) { return new WebHostBuilder() @@ -267,5 +408,6 @@ private IWebHost BuildWebHost(Action configure) .UseUrls("http://127.0.0.1:0") .Build(); } +#pragma warning restore CS0618 // Type or member is obsolete } } diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj index 24958a00b6c4..527ae19b1945 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -13,8 +13,8 @@ - + diff --git a/src/SignalR/server/SignalR/test/Startup.cs b/src/SignalR/server/SignalR/test/Startup.cs index 0f8ff7f21a21..25612f917f40 100644 --- a/src/SignalR/server/SignalR/test/Startup.cs +++ b/src/SignalR/server/SignalR/test/Startup.cs @@ -1,14 +1,24 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.AspNetCore.Authentication.Cookies; +using System; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; namespace Microsoft.AspNetCore.SignalR.Tests { public class Startup { + private readonly SymmetricSecurityKey SecurityKey = new SymmetricSecurityKey(Guid.NewGuid().ToByteArray()); + private readonly JwtSecurityTokenHandler JwtTokenHandler = new JwtSecurityTokenHandler(); + public void ConfigureServices(IServiceCollection services) { services.AddConnections(); @@ -19,11 +29,40 @@ public void ConfigureServices(IServiceCollection services) services.AddAuthentication(options => { - options.DefaultAuthenticateScheme = CookieAuthenticationDefaults.AuthenticationScheme; - options.DefaultChallengeScheme = CookieAuthenticationDefaults.AuthenticationScheme; - }).AddCookie(); + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + options.DefaultChallengeScheme = JwtBearerDefaults.AuthenticationScheme; + }).AddJwtBearer(options => + { + options.TokenValidationParameters = + new TokenValidationParameters + { + LifetimeValidator = (before, expires, token, parameters) => expires > DateTime.UtcNow, + ValidateAudience = false, + ValidateIssuer = false, + ValidateActor = false, + ValidateLifetime = true, + IssuerSigningKey = SecurityKey + }; + + options.Events = new JwtBearerEvents + { + OnMessageReceived = context => + { + var accessToken = context.Request.Query["access_token"]; + + if (!string.IsNullOrEmpty(accessToken) && + (context.HttpContext.WebSockets.IsWebSocketRequest || context.Request.Headers["Accept"] == "text/event-stream")) + { + context.Token = context.Request.Query["access_token"]; + } + return Task.CompletedTask; + } + }; + }); services.AddAuthorization(); + + services.AddSingleton(); } public void Configure(IApplicationBuilder app) @@ -32,15 +71,37 @@ public void Configure(IApplicationBuilder app) app.UseAuthentication(); app.UseAuthorization(); + // Legacy routing, runs different code path for mapping hubs +#pragma warning disable CS0618 // Type or member is obsolete + app.UseSignalR(routes => + { + routes.MapHub("/authHub"); + }); +#pragma warning restore CS0618 // Type or member is obsolete + app.UseEndpoints(endpoints => { endpoints.MapHub("/uncreatable"); + endpoints.MapHub("/authHubEndpoints"); endpoints.MapConnectionHandler("/echo"); endpoints.MapConnectionHandler("/echoAndClose"); endpoints.MapConnectionHandler("/httpheader"); endpoints.MapConnectionHandler("/auth"); + + endpoints.MapGet("/generatetoken", context => + { + return context.Response.WriteAsync(GenerateToken(context)); + }); }); } + + private string GenerateToken(HttpContext httpContext) + { + var claims = new[] { new Claim(ClaimTypes.NameIdentifier, httpContext.Request.Query["user"]) }; + var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256); + var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.UtcNow.AddMinutes(1), signingCredentials: credentials); + return JwtTokenHandler.WriteToken(token); + } } } diff --git a/src/SignalR/server/SignalR/test/TestAuthHandler.cs b/src/SignalR/server/SignalR/test/TestAuthHandler.cs new file mode 100644 index 000000000000..0e070944a1a8 --- /dev/null +++ b/src/SignalR/server/SignalR/test/TestAuthHandler.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Authorization; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class TestAuthHandler : IAuthorizationHandler + { + public Task HandleAsync(AuthorizationHandlerContext context) + { + foreach (var req in context.Requirements) + { + context.Succeed(req); + } + + var hasClaim = context.User.HasClaim(o => o.Type == ClaimTypes.NameIdentifier && !string.IsNullOrEmpty(o.Value)); + + if (!hasClaim) + { + context.Fail(); + } + + return Task.CompletedTask; + } + } +} diff --git a/src/SignalR/server/SignalR/test/UncreatableHub.cs b/src/SignalR/server/SignalR/test/UncreatableHub.cs index 97142cf2716e..e481613f727c 100644 --- a/src/SignalR/server/SignalR/test/UncreatableHub.cs +++ b/src/SignalR/server/SignalR/test/UncreatableHub.cs @@ -1,9 +1,9 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. namespace Microsoft.AspNetCore.SignalR.Tests { - public class UncreatableHub: Hub + public class UncreatableHub : Hub { public UncreatableHub(object obj) {