@@ -125,8 +125,18 @@ type clusterUpdate struct {
125125
126126// Dispatcher is responsible for dispatching tasks and tracking agent health.
127127type Dispatcher struct {
128- mu sync.Mutex
129- wg sync.WaitGroup
128+ // mu is a lock to provide mutually exclusive access to dispatcher fields
129+ // e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc.
130+ // It also makes setting the shutdown flag to 'true' atomic with respect to the
131+ // Add() operation on shutdownWait, ensuring that Stop() waits for all
132+ // in-progress operations to finish and that no new operations can start.
133+ mu sync.Mutex
134+ // shutdown is a flag to indicate shutdown and prevent new operations on the dispatcher.
135+ // Set by calling Stop().
136+ shutdown bool
137+ // shutdownWait is used by Stop() to wait for existing operations to finish.
138+ shutdownWait sync.WaitGroup
139+
130140 nodes * nodeStore
131141 store * store.MemoryStore
132142 lastSeenManagers []* api.WeightedPeer
@@ -195,6 +205,12 @@ func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
195205// Run runs dispatcher tasks which should be run on leader dispatcher.
196206// Dispatcher can be stopped with cancelling ctx or calling Stop().
197207func (d * Dispatcher ) Run (ctx context.Context ) error {
208+ // The dispatcher object is not recreated when a node re-gains
209+ // leadership. We need to reset to default state.
210+ d .mu .Lock ()
211+ d .shutdown = false
212+ d .mu .Unlock ()
213+
198214 ctx = log .WithModule (ctx , "dispatcher" )
199215 log .G (ctx ).Info ("dispatcher starting" )
200216
@@ -249,8 +265,8 @@ func (d *Dispatcher) Run(ctx context.Context) error {
249265 defer cancel ()
250266 d .ctx , d .cancel = context .WithCancel (ctx )
251267 ctx = d .ctx
252- d .wg .Add (1 )
253- defer d .wg .Done ()
268+ d .shutdownWait .Add (1 )
269+ defer d .shutdownWait .Done ()
254270 d .mu .Unlock ()
255271
256272 publishManagers := func (peers []* api.Peer ) {
@@ -313,11 +329,19 @@ func (d *Dispatcher) Stop() error {
313329 return errors .New ("dispatcher is already stopped" )
314330 }
315331
316- log := log .G (d .ctx ).WithField ("method" , "(*Dispatcher).Stop" )
317- log .Info ("dispatcher stopping" )
332+ // Set shutdown to true.
333+ // This prevents RPCs that start after Stop() is called from making
334+ // progress, which effectively puts the dispatcher into a draining state.
335+ d .shutdown = true
336+
337+ // Cancel dispatcher context.
338+ // This should also close the streams in Tasks(), Assignments().
318339 d .cancel ()
319340 d .mu .Unlock ()
320341
342+ // Wait for the RPCs that are in-progress to finish.
343+ d .shutdownWait .Wait ()
344+
321345 d .nodes .Clean ()
322346
323347 d .processUpdatesLock .Lock ()
@@ -328,9 +352,6 @@ func (d *Dispatcher) Stop() error {
328352 d .processUpdatesLock .Unlock ()
329353
330354 d .clusterUpdateQueue .Close ()
331-
332- d .wg .Wait ()
333-
334355 return nil
335356}
336357
@@ -478,13 +499,19 @@ func nodeIPFromContext(ctx context.Context) (string, error) {
478499
479500// register is used for registration of node with particular dispatcher.
480501func (d * Dispatcher ) register (ctx context.Context , nodeID string , description * api.NodeDescription ) (string , error ) {
481- logLocal := log .G (ctx ).WithField ("method" , "(*Dispatcher).register" )
482- // prevent register until we're ready to accept it
483- dctx , err := d .isRunningLocked ()
484- if err != nil {
485- return "" , err
502+ d .mu .Lock ()
503+ if d .shutdown {
504+ d .mu .Unlock ()
505+ return "" , status .Errorf (codes .Aborted , "dispatcher is stopped" )
486506 }
487507
508+ dctx := d .ctx
509+ d .shutdownWait .Add (1 )
510+ defer d .shutdownWait .Done ()
511+ d .mu .Unlock ()
512+
513+ logLocal := log .G (ctx ).WithField ("method" , "(*Dispatcher).register" )
514+
488515 if err := d .nodes .CheckRateLimit (nodeID ); err != nil {
489516 return "" , err
490517 }
@@ -532,6 +559,16 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a
532559// UpdateTaskStatus updates status of task. Node should send such updates
533560// on every status change of its tasks.
534561func (d * Dispatcher ) UpdateTaskStatus (ctx context.Context , r * api.UpdateTaskStatusRequest ) (* api.UpdateTaskStatusResponse , error ) {
562+ d .mu .Lock ()
563+ if d .shutdown {
564+ d .mu .Unlock ()
565+ return nil , status .Errorf (codes .Aborted , "dispatcher is stopped" )
566+ }
567+ dctx := d .ctx
568+ d .shutdownWait .Add (1 )
569+ defer d .shutdownWait .Done ()
570+ d .mu .Unlock ()
571+
535572 nodeInfo , err := ca .RemoteNode (ctx )
536573 if err != nil {
537574 return nil , err
@@ -547,11 +584,6 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat
547584 }
548585 log := log .G (ctx ).WithFields (fields )
549586
550- dctx , err := d .isRunningLocked ()
551- if err != nil {
552- return nil , err
553- }
554-
555587 if _ , err := d .nodes .GetWithSession (nodeID , r .SessionID ); err != nil {
556588 return nil , err
557589 }
@@ -774,6 +806,16 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe
774806 defer cancel ()
775807
776808 for {
809+ d .mu .Lock ()
810+ if d .shutdown {
811+ d .mu .Unlock ()
812+ return status .Errorf (codes .Aborted , "dispatcher is stopped" )
813+ }
814+
815+ d .shutdownWait .Add (1 ) // registered while d.mu is held, so the shutdown check and the Add() are atomic
816+ defer d .shutdownWait .Done () // NOTE(review): defer inside the for loop runs only when Tasks returns; if the loop iterates more than once, one Add()/deferred Done() pair accumulates per iteration — consider registering once before the loop
817+ d .mu .Unlock ()
818+
777819 if _ , err := d .nodes .GetWithSession (nodeID , r .SessionID ); err != nil {
778820 return err
779821 }
@@ -919,6 +961,16 @@ func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatche
919961 }
920962
921963 for {
964+ d .mu .Lock ()
965+ if d .shutdown {
966+ d .mu .Unlock ()
967+ return nil // NOTE(review): the other shutdown checks in this file return codes.Aborted ("dispatcher is stopped"); confirm that a nil return is intentional here
968+ }
969+
970+ d .shutdownWait .Add (1 ) // registered while d.mu is held, so the shutdown check and the Add() are atomic
971+ defer d .shutdownWait .Done () // NOTE(review): defer inside the for loop runs only when Assignments returns; if the loop iterates more than once, one Add()/deferred Done() pair accumulates per iteration — consider registering once before the loop
972+ d .mu .Unlock ()
973+
922974 // Check for session expiration
923975 if _ , err := d .nodes .GetWithSession (nodeID , r .SessionID ); err != nil {
924976 return err
@@ -1103,6 +1155,15 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes
11031155// Node should send new heartbeat earlier than now + TTL, otherwise it will
11041156// be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
11051157func (d * Dispatcher ) Heartbeat (ctx context.Context , r * api.HeartbeatRequest ) (* api.HeartbeatResponse , error ) {
1158+ d .mu .Lock ()
1159+ if d .shutdown {
1160+ d .mu .Unlock ()
1161+ return nil , status .Errorf (codes .Aborted , "dispatcher is stopped" )
1162+ }
1163+ d .shutdownWait .Add (1 )
1164+ defer d .shutdownWait .Done ()
1165+ d .mu .Unlock ()
1166+
11061167 nodeInfo , err := ca .RemoteNode (ctx )
11071168 if err != nil {
11081169 return nil , err
@@ -1232,6 +1293,15 @@ func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_Sessio
12321293 }
12331294
12341295 for {
1296+ d .mu .Lock ()
1297+ if d .shutdown {
1298+ d .mu .Unlock ()
1299+ return status .Errorf (codes .Aborted , "dispatcher is stopped" ) // must return: d.mu is already released here, and falling through would Unlock() it a second time
1300+ }
1301+ d .shutdownWait .Add (1 ) // registered while d.mu is held, so the shutdown check and the Add() are atomic
1302+ defer d .shutdownWait .Done () // NOTE(review): defer inside the for loop runs only when Session returns; if the loop iterates more than once, one Add()/deferred Done() pair accumulates per iteration — consider registering once before the loop
1303+ d .mu .Unlock ()
1304+
12351305 // After each message send, we need to check the nodes sessionID hasn't
12361306 // changed. If it has, we will shut down the stream and make the node
12371307 // re-register.
0 commit comments