diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index 6bc1a25bdb..cfca0a5590 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -171,5 +171,10 @@ public override Task RemoveGroupAsync(string connectionId, string groupName) { return _wrappedHubLifetimeManager.RemoveGroupAsync(connectionId, groupName); } + + public override Task InvokeGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedIds) + { + return _wrappedHubLifetimeManager.InvokeGroupExceptAsync(groupName, methodName, args, excludedIds); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index c036489dbb..4a2612f78f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -129,6 +129,25 @@ public override Task InvokeGroupAsync(string groupName, string methodName, objec return Task.CompletedTask; } + public override Task InvokeGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedIds) + { + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var group = _groups[groupName]; + if (group != null) + { + var message = CreateInvocationMessage(methodName, args); + var tasks = group.Values.Where(connection => !excludedIds.Contains(connection.ConnectionId)) + .Select(c => c.WriteAsync(message)); + return Task.WhenAll(tasks); + } + + return Task.CompletedTask; + } + private InvocationMessage CreateInvocationMessage(string methodName, object[] args) { return new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, argumentBindingException: null, arguments: args); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DynamicHubClients.cs b/src/Microsoft.AspNetCore.SignalR.Core/DynamicHubClients.cs index 09a15c8db4..14c58be703 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DynamicHubClients.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DynamicHubClients.cs @@ -18,7 +18,8 @@ public DynamicHubClients(IHubCallerClients clients) public dynamic AllExcept(IReadOnlyList excludedIds) => new DynamicClientProxy(_clients.AllExcept(excludedIds)); public dynamic Caller => new DynamicClientProxy(_clients.Caller); public dynamic Client(string connectionId) => new DynamicClientProxy(_clients.Client(connectionId)); - public dynamic Group(string group) => new DynamicClientProxy(_clients.Group(group)); + public dynamic Group(string groupName) => new DynamicClientProxy(_clients.Group(groupName)); + public dynamic GroupExcept(string groupName, IReadOnlyList excludedIds) => new DynamicClientProxy(_clients.GroupExcept(groupName, excludedIds)); public dynamic Others => new DynamicClientProxy(_clients.Others); public dynamic User(string userId) => new DynamicClientProxy(_clients.User(userId)); } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubCallerClients.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubCallerClients.cs index 1ff7d5d34b..b90d7df5c8 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubCallerClients.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubCallerClients.cs @@ -39,6 +39,11 @@ public IClientProxy Group(string groupName) return _hubClients.Group(groupName); } + public IClientProxy GroupExcept(string groupName, IReadOnlyList excludeIds) + { + return _hubClients.GroupExcept(groupName, excludeIds); + } + public IClientProxy User(string userId) { return _hubClients.User(userId); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubContext.cs index 8fbcf2d0bc..de4a55c61a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubContext.cs @@ -37,6 +37,11 @@ public virtual IClientProxy Group(string groupName) return new GroupProxy(_lifetimeManager, groupName); } + public IClientProxy GroupExcept(string groupName, IReadOnlyList excludeIds) + { + return new GroupExceptProxy(_lifetimeManager, groupName, excludeIds); + } + public virtual IClientProxy User(string userId) { return new UserProxy(_lifetimeManager, userId); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubContext`T.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubContext`T.cs index 99b1aea9a4..ceba0bd81e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubContext`T.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubContext`T.cs @@ -39,6 +39,11 @@ public virtual T Group(string groupName) return TypedClientBuilder.Build(new GroupProxy(_lifetimeManager, groupName)); } + public T GroupExcept(string groupName, IReadOnlyList excludeIds) + { + return TypedClientBuilder.Build(new GroupExceptProxy(_lifetimeManager, groupName, excludeIds)); + } + public virtual T User(string userId) { return TypedClientBuilder.Build(new UserProxy(_lifetimeManager, userId)); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubLifetimeManager.cs index 826da7faf3..939209ef58 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubLifetimeManager.cs @@ -20,6 +20,8 @@ public abstract class HubLifetimeManager public abstract Task InvokeGroupAsync(string groupName, string methodName, object[] args); + public abstract Task InvokeGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedIds); + public abstract Task InvokeUserAsync(string userId, string methodName, object[] args); public abstract Task AddGroupAsync(string connectionId, string groupName); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/IHubClients`T.cs b/src/Microsoft.AspNetCore.SignalR.Core/IHubClients`T.cs index cf62e6ca22..87d6fc41a1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/IHubClients`T.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/IHubClients`T.cs @@ -15,6 +15,8 @@ public interface IHubClients T Group(string groupName); + T GroupExcept(string groupName, IReadOnlyList excludeIds); + T User(string userId); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Proxies.cs b/src/Microsoft.AspNetCore.SignalR.Core/Proxies.cs index 5d4672fc53..cfa812c44b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Proxies.cs @@ -40,6 +40,25 @@ public Task InvokeAsync(string method, params object[] args) } } + public class GroupExceptProxy : IClientProxy + { + private readonly string _groupName; + private readonly HubLifetimeManager _lifetimeManager; + private readonly IReadOnlyList _excludedIds; + + public GroupExceptProxy(HubLifetimeManager lifetimeManager, string groupName, IReadOnlyList excludedIds) + { + _lifetimeManager = lifetimeManager; + _groupName = groupName; + _excludedIds = excludedIds; + } + + public Task InvokeAsync(string method, params object[] args) + { + return _lifetimeManager.InvokeGroupExceptAsync(_groupName, method, args, _excludedIds); + } + } + public class AllClientProxy : IClientProxy { private readonly HubLifetimeManager _lifetimeManager; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/TypedHubClients.cs b/src/Microsoft.AspNetCore.SignalR.Core/TypedHubClients.cs index 34d9974201..68fb565344 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/TypedHubClients.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/TypedHubClients.cs @@ -32,6 +32,11 @@ public T Group(string groupName) return TypedClientBuilder.Build(_hubClients.Group(groupName)); } + public T GroupExcept(string groupName, IReadOnlyList excludeIds) + { + return TypedClientBuilder.Build(_hubClients.GroupExcept(groupName, excludeIds)); + } + public T User(string userId) { return TypedClientBuilder.Build(_hubClients.User(userId)); diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index a4eff48152..44426a04da 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -190,7 +190,19 @@ public override Task InvokeGroupAsync(string groupName, string methodName, objec throw new ArgumentNullException(nameof(groupName)); } - var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, argumentBindingException: null, arguments: args); + var message = new RedisExcludeClientsMessage(GetInvocationId(), nonBlocking: true, target: methodName, excludedIds: null, arguments: args); + + return PublishAsync(_channelNamePrefix + ".group." + groupName, message); + } + + public override Task InvokeGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedIds) + { + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var message = new RedisExcludeClientsMessage(GetInvocationId(), nonBlocking: true, target: methodName, excludedIds: excludedIds, arguments: args); return PublishAsync(_channelNamePrefix + ".group." + groupName, message); } @@ -547,11 +559,16 @@ private Task SubscribeToGroup(string groupChannel, GroupData group) { try { - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); - var tasks = new List(group.Connections.Count); + var tasks = new List(); foreach (var groupConnection in group.Connections) { + if (message.ExcludedIds?.Contains(groupConnection.ConnectionId) == true) + { + continue; + } + tasks.Add(groupConnection.WriteAsync(message)); } diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index b92c45b291..aba7461fdc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -94,6 +95,37 @@ public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout(); + await AssertMessageAsync(client1); + Assert.Null(client2.TryRead()); + + await connection1.DisposeAsync().OrTimeout(); + await connection2.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeGroupExceptAsyncWritesToAllValidConnectionsInGroupOutput() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), + Options.Create(new RedisOptions() + { + Factory = t => new TestConnectionMultiplexer() + })); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout(); + await manager.AddGroupAsync(connection2.ConnectionId, "gunit").OrTimeout(); + + var excludedIds = new List{ client2.Connection.ConnectionId }; + await manager.InvokeGroupExceptAsync("gunit", "Hello", new object[] { "World" }, excludedIds).OrTimeout(); + await AssertMessageAsync(client1); await connection1.DisposeAsync().OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 00e1d378b6..a98926904a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -992,6 +992,58 @@ public async Task HubsCanAddAndSendToGroup(Type hubType) } } + [Theory] + [MemberData(nameof(HubTypes))] + public async Task SendToGroupExcept(Type hubType) + { + var serviceProvider = CreateServiceProvider(); + + dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType)); + + using (var firstClient = new TestClient()) + using (var secondClient = new TestClient()) + { + Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); + Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); + + await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); + + var result = (await firstClient.InvokeAsync("GroupSendMethod", "testGroup", "test").OrTimeout()).Result; + + // check that 'firstConnection' hasn't received the group send + Assert.Null(firstClient.TryRead()); + + // check that 'secondConnection' hasn't received the group send + Assert.Null(secondClient.TryRead()); + + await firstClient.InvokeAsync(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout(); + await secondClient.InvokeAsync(nameof(MethodHub.GroupAddMethod), "testGroup").OrTimeout(); + + var excludedIds = new List { firstClient.Connection.ConnectionId }; + + await firstClient.SendInvocationAsync("GroupExceptSendMethod", "testGroup", "test", excludedIds).OrTimeout(); + + // check that 'secondConnection' has received the group send + var hubMessage = await secondClient.ReadAsync().OrTimeout(); + var invocation = Assert.IsType(hubMessage); + Assert.Equal("Send", invocation.Target); + Assert.Single(invocation.Arguments); + Assert.Equal("test", invocation.Arguments[0]); + + // Check that first client only got the completion message + hubMessage = await firstClient.ReadAsync().OrTimeout(); + Assert.IsType(hubMessage); + + Assert.Null(firstClient.TryRead()); + + // kill the connections + firstClient.Dispose(); + secondClient.Dispose(); + + await Task.WhenAll(firstEndPointTask, secondEndPointTask).OrTimeout(); + } + } + [Fact] public async Task RemoveFromGroupWhenNotInGroupDoesNotFail() { @@ -1626,6 +1678,11 @@ public Task GroupSendMethod(string groupName, string message) return Clients.Group(groupName).Send(message); } + public Task GroupExceptSendMethod(string groupName, string message, IReadOnlyList excludedIds) + { + return Clients.GroupExcept(groupName, excludedIds).Send(message); + } + public Task BroadcastMethod(string message) { return Clients.All.Broadcast(message); @@ -1814,6 +1871,11 @@ public Task GroupSendMethod(string groupName, string message) return Clients.Group(groupName).Send(message); } + public Task GroupExceptSendMethod(string groupName, string message, IReadOnlyList excludedIds) + { + return Clients.GroupExcept(groupName, excludedIds).Send(message); + } + public Task BroadcastMethod(string message) { return Clients.All.Broadcast(message); @@ -1945,6 +2007,11 @@ public Task GroupSendMethod(string groupName, string message) return Clients.Group(groupName).InvokeAsync("Send", message); } + public Task GroupExceptSendMethod(string groupName, string message, IReadOnlyList excludedIds) + { + return Clients.GroupExcept(groupName, excludedIds).InvokeAsync("Send", message); + } + public Task BroadcastMethod(string message) { return Clients.All.InvokeAsync("Broadcast", message);