Skip to content

Commit 4b66525

Browse files
committed
[manager/dispatcher] Synchronize Dispatcher.Stop() with incoming rpcs.
Signed-off-by: Anshul Pundir <anshul.pundir@docker.com>
1 parent 68a376d commit 4b66525

File tree

1 file changed

+89
-19
lines changed

1 file changed

+89
-19
lines changed

manager/dispatcher/dispatcher.go

Lines changed: 89 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -125,8 +125,18 @@ type clusterUpdate struct {
125125

126126
// Dispatcher is responsible for dispatching tasks and tracking agent health.
127127
type Dispatcher struct {
128-
mu sync.Mutex
129-
wg sync.WaitGroup
128+
// mu is a lock to provide mutually exclusive access to dispatcher fields
129+
// e.g. lastSeenManagers, networkBootstrapKeys, lastSeenRootCert etc.
130+
// Also used to make atomic the setting of the shutdown flag to 'true' and the
131+
// Add() operation on the shutdownWait to make sure that stop() waits for
132+
// all operations to finish and disallow new operations from starting.
133+
mu sync.Mutex
134+
// shutdown is a flag to indicate shutdown and prevent new operations on the dispatcher.
135+
// Set by calling Stop().
136+
shutdown bool
137+
// shutdownWait is used by stop() to wait for existing operations to finish.
138+
shutdownWait sync.WaitGroup
139+
130140
nodes *nodeStore
131141
store *store.MemoryStore
132142
lastSeenManagers []*api.WeightedPeer
@@ -195,6 +205,12 @@ func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
195205
// Run runs dispatcher tasks which should be run on leader dispatcher.
196206
// Dispatcher can be stopped with cancelling ctx or calling Stop().
197207
func (d *Dispatcher) Run(ctx context.Context) error {
208+
// The dispatcher object is not recreated when a node re-gains
209+
// leadership. We need to reset to default state.
210+
d.mu.Lock()
211+
d.shutdown = false
212+
d.mu.Unlock()
213+
198214
ctx = log.WithModule(ctx, "dispatcher")
199215
log.G(ctx).Info("dispatcher starting")
200216

@@ -249,8 +265,8 @@ func (d *Dispatcher) Run(ctx context.Context) error {
249265
defer cancel()
250266
d.ctx, d.cancel = context.WithCancel(ctx)
251267
ctx = d.ctx
252-
d.wg.Add(1)
253-
defer d.wg.Done()
268+
d.shutdownWait.Add(1)
269+
defer d.shutdownWait.Done()
254270
d.mu.Unlock()
255271

256272
publishManagers := func(peers []*api.Peer) {
@@ -313,11 +329,19 @@ func (d *Dispatcher) Stop() error {
313329
return errors.New("dispatcher is already stopped")
314330
}
315331

316-
log := log.G(d.ctx).WithField("method", "(*Dispatcher).Stop")
317-
log.Info("dispatcher stopping")
332+
// Set shutdown to true.
333+
// This will prevent RPCs that start after stop() is called
334+
// from making progress and essentially puts the dispatcher in drain.
335+
d.shutdown = true
336+
337+
// Cancel dispatcher context.
338+
// This should also close the streams in Tasks(), Assignments().
318339
d.cancel()
319340
d.mu.Unlock()
320341

342+
// Wait for the RPCs that are in-progress to finish.
343+
d.shutdownWait.Wait()
344+
321345
d.nodes.Clean()
322346

323347
d.processUpdatesLock.Lock()
@@ -328,9 +352,6 @@ func (d *Dispatcher) Stop() error {
328352
d.processUpdatesLock.Unlock()
329353

330354
d.clusterUpdateQueue.Close()
331-
332-
d.wg.Wait()
333-
334355
return nil
335356
}
336357

@@ -478,13 +499,19 @@ func nodeIPFromContext(ctx context.Context) (string, error) {
478499

479500
// register is used for registration of node with particular dispatcher.
480501
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) {
481-
logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register")
482-
// prevent register until we're ready to accept it
483-
dctx, err := d.isRunningLocked()
484-
if err != nil {
485-
return "", err
502+
d.mu.Lock()
503+
if d.shutdown {
504+
d.mu.Unlock()
505+
return "", status.Errorf(codes.Aborted, "dispatcher is stopped")
486506
}
487507

508+
dctx := d.ctx
509+
d.shutdownWait.Add(1)
510+
defer d.shutdownWait.Done()
511+
d.mu.Unlock()
512+
513+
logLocal := log.G(ctx).WithField("method", "(*Dispatcher).register")
514+
488515
if err := d.nodes.CheckRateLimit(nodeID); err != nil {
489516
return "", err
490517
}
@@ -532,6 +559,16 @@ func (d *Dispatcher) register(ctx context.Context, nodeID string, description *a
532559
// UpdateTaskStatus updates status of task. Node should send such updates
533560
// on every status change of its tasks.
534561
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
562+
d.mu.Lock()
563+
if d.shutdown {
564+
d.mu.Unlock()
565+
return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
566+
}
567+
dctx := d.ctx
568+
d.shutdownWait.Add(1)
569+
defer d.shutdownWait.Done()
570+
d.mu.Unlock()
571+
535572
nodeInfo, err := ca.RemoteNode(ctx)
536573
if err != nil {
537574
return nil, err
@@ -547,11 +584,6 @@ func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStat
547584
}
548585
log := log.G(ctx).WithFields(fields)
549586

550-
dctx, err := d.isRunningLocked()
551-
if err != nil {
552-
return nil, err
553-
}
554-
555587
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
556588
return nil, err
557589
}
@@ -774,6 +806,16 @@ func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServe
774806
defer cancel()
775807

776808
for {
809+
d.mu.Lock()
810+
if d.shutdown {
811+
d.mu.Unlock()
812+
return status.Errorf(codes.Aborted, "dispatcher is stopped")
813+
}
814+
815+
d.shutdownWait.Add(1)
816+
defer d.shutdownWait.Done()
817+
d.mu.Unlock()
818+
777819
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
778820
return err
779821
}
@@ -919,6 +961,16 @@ func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatche
919961
}
920962

921963
for {
964+
d.mu.Lock()
965+
if d.shutdown {
966+
d.mu.Unlock()
967+
return nil
968+
}
969+
970+
d.shutdownWait.Add(1)
971+
defer d.shutdownWait.Done()
972+
d.mu.Unlock()
973+
922974
// Check for session expiration
923975
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
924976
return err
@@ -1103,6 +1155,15 @@ func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, mes
11031155
// Node should send new heartbeat earlier than now + TTL, otherwise it will
11041156
// be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
11051157
func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
1158+
d.mu.Lock()
1159+
if d.shutdown {
1160+
d.mu.Unlock()
1161+
return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
1162+
}
1163+
d.shutdownWait.Add(1)
1164+
defer d.shutdownWait.Done()
1165+
d.mu.Unlock()
1166+
11061167
nodeInfo, err := ca.RemoteNode(ctx)
11071168
if err != nil {
11081169
return nil, err
@@ -1232,6 +1293,15 @@ func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_Sessio
12321293
}
12331294

12341295
for {
1296+
d.mu.Lock()
1297+
if d.shutdown {
1298+
d.mu.Unlock()
1299+
status.Errorf(codes.Aborted, "dispatcher is stopped")
1300+
}
1301+
d.shutdownWait.Add(1)
1302+
defer d.shutdownWait.Done()
1303+
d.mu.Unlock()
1304+
12351305
// After each message send, we need to check the nodes sessionID hasn't
12361306
// changed. If it has, we will shut down the stream and make the node
12371307
// re-register.

0 commit comments

Comments
 (0)