@@ -18,6 +18,9 @@ enum RequestType {
18
18
REMOTE_DISCONNECT = 4 ,
19
19
REMOTE_FETCH = 5 ,
20
20
SERVER_SIDE_EMIT = 6 ,
21
+ BROADCAST ,
22
+ BROADCAST_CLIENT_COUNT ,
23
+ BROADCAST_ACK ,
21
24
}
22
25
23
26
interface Request {
@@ -29,6 +32,11 @@ interface Request {
29
32
[ other : string ] : any ;
30
33
}
31
34
35
+ interface AckRequest {
36
+ clientCountCallback : ( clientCount : number ) => void ;
37
+ ack : ( ...args : any [ ] ) => void ;
38
+ }
39
+
32
40
const isNumeric = ( str ) => ! isNaN ( str ) && ! isNaN ( parseFloat ( str ) ) ;
33
41
34
42
export interface RedisAdapterOptions {
@@ -84,6 +92,7 @@ export class RedisAdapter extends Adapter {
84
92
private readonly requestChannel : string ;
85
93
private readonly responseChannel : string ;
86
94
private requests : Map < string , Request > = new Map ( ) ;
95
+ private ackRequests : Map < string , AckRequest > = new Map ( ) ;
87
96
88
97
/**
89
98
* Adapter constructor.
@@ -127,7 +136,8 @@ export class RedisAdapter extends Adapter {
127
136
[ this . requestChannel , this . responseChannel , specificResponseChannel ] ,
128
137
( msg , channel ) => {
129
138
this . onrequest ( channel , msg ) ;
130
- }
139
+ } ,
140
+ true
131
141
) ;
132
142
} else {
133
143
this . subClient . psubscribe ( this . channel + "*" ) ;
@@ -212,7 +222,12 @@ export class RedisAdapter extends Adapter {
212
222
let request ;
213
223
214
224
try {
215
- request = JSON . parse ( msg ) ;
225
+ // if the buffer starts with a "{" character
226
+ if ( msg [ 0 ] === 0x7b ) {
227
+ request = JSON . parse ( msg . toString ( ) ) ;
228
+ } else {
229
+ request = msgpack . decode ( msg ) ;
230
+ }
216
231
} catch ( err ) {
217
232
debug ( "ignoring malformed request" ) ;
218
233
return ;
@@ -379,6 +394,47 @@ export class RedisAdapter extends Adapter {
379
394
this . nsp . _onServerSideEmit ( request . data ) ;
380
395
break ;
381
396
397
+ case RequestType . BROADCAST : {
398
+ if ( this . ackRequests . has ( request . requestId ) ) {
399
+ // ignore self
400
+ return ;
401
+ }
402
+
403
+ const opts = {
404
+ rooms : new Set < Room > ( request . opts . rooms ) ,
405
+ except : new Set < Room > ( request . opts . except ) ,
406
+ } ;
407
+
408
+ super . broadcastWithAck (
409
+ request . packet ,
410
+ opts ,
411
+ ( clientCount ) => {
412
+ debug ( "waiting for %d client acknowledgements" , clientCount ) ;
413
+ this . publishResponse (
414
+ request ,
415
+ JSON . stringify ( {
416
+ type : RequestType . BROADCAST_CLIENT_COUNT ,
417
+ requestId : request . requestId ,
418
+ clientCount,
419
+ } )
420
+ ) ;
421
+ } ,
422
+ ( arg ) => {
423
+ debug ( "received acknowledgement with value %j" , arg ) ;
424
+
425
+ this . publishResponse (
426
+ request ,
427
+ msgpack . encode ( {
428
+ type : RequestType . BROADCAST_ACK ,
429
+ requestId : request . requestId ,
430
+ packet : arg ,
431
+ } )
432
+ ) ;
433
+ }
434
+ ) ;
435
+ break ;
436
+ }
437
+
382
438
default :
383
439
debug ( "ignoring unknown request type: %s" , request . type ) ;
384
440
}
@@ -407,15 +463,40 @@ export class RedisAdapter extends Adapter {
407
463
let response ;
408
464
409
465
try {
410
- response = JSON . parse ( msg ) ;
466
+ // if the buffer starts with a "{" character
467
+ if ( msg [ 0 ] === 0x7b ) {
468
+ response = JSON . parse ( msg . toString ( ) ) ;
469
+ } else {
470
+ response = msgpack . decode ( msg ) ;
471
+ }
411
472
} catch ( err ) {
412
473
debug ( "ignoring malformed response" ) ;
413
474
return ;
414
475
}
415
476
416
477
const requestId = response . requestId ;
417
478
418
- if ( ! requestId || ! this . requests . has ( requestId ) ) {
479
+ if ( this . ackRequests . has ( requestId ) ) {
480
+ const ackRequest = this . ackRequests . get ( requestId ) ;
481
+
482
+ switch ( response . type ) {
483
+ case RequestType . BROADCAST_CLIENT_COUNT : {
484
+ ackRequest ?. clientCountCallback ( response . clientCount ) ;
485
+ break ;
486
+ }
487
+
488
+ case RequestType . BROADCAST_ACK : {
489
+ ackRequest ?. ack ( response . packet ) ;
490
+ break ;
491
+ }
492
+ }
493
+ return ;
494
+ }
495
+
496
+ if (
497
+ ! requestId ||
498
+ ! ( this . requests . has ( requestId ) || this . ackRequests . has ( requestId ) )
499
+ ) {
419
500
debug ( "ignoring unknown request" ) ;
420
501
return ;
421
502
}
@@ -526,6 +607,50 @@ export class RedisAdapter extends Adapter {
526
607
super . broadcast ( packet , opts ) ;
527
608
}
528
609
610
+ public broadcastWithAck (
611
+ packet : any ,
612
+ opts : BroadcastOptions ,
613
+ clientCountCallback : ( clientCount : number ) => void ,
614
+ ack : ( ...args : any [ ] ) => void
615
+ ) {
616
+ packet . nsp = this . nsp . name ;
617
+
618
+ const onlyLocal = opts ?. flags ?. local ;
619
+
620
+ if ( ! onlyLocal ) {
621
+ const requestId = uid2 ( 6 ) ;
622
+
623
+ const rawOpts = {
624
+ rooms : [ ...opts . rooms ] ,
625
+ except : [ ...new Set ( opts . except ) ] ,
626
+ flags : opts . flags ,
627
+ } ;
628
+
629
+ const request = msgpack . encode ( {
630
+ uid : this . uid ,
631
+ requestId,
632
+ type : RequestType . BROADCAST ,
633
+ packet,
634
+ opts : rawOpts ,
635
+ } ) ;
636
+
637
+ this . pubClient . publish ( this . requestChannel , request ) ;
638
+
639
+ this . ackRequests . set ( requestId , {
640
+ clientCountCallback,
641
+ ack,
642
+ } ) ;
643
+
644
+ // we have no way to know at this level whether the server has received an acknowledgement from each client, so we
645
+ // will simply clean up the ackRequests map after the given delay
646
+ setTimeout ( ( ) => {
647
+ this . ackRequests . delete ( requestId ) ;
648
+ } , opts . flags ! . timeout ) ;
649
+ }
650
+
651
+ super . broadcastWithAck ( packet , opts , clientCountCallback , ack ) ;
652
+ }
653
+
529
654
/**
530
655
* Gets a list of sockets by sid.
531
656
*
@@ -955,4 +1080,8 @@ export class RedisAdapter extends Adapter {
955
1080
} ) ;
956
1081
}
957
1082
}
1083
+
1084
+ serverCount ( ) : Promise < number > {
1085
+ return this . getNumSub ( ) ;
1086
+ }
958
1087
}
0 commit comments