Skip to content

Commit 0c19ad2

Browse files
authored
fix(sampler): call set_epoch every epoch start in distributed training (#130)
1 parent 9ca5397 commit 0c19ad2

File tree

2 files changed

+12
-8
lines changed
  • src/templates

2 files changed

+12
-8
lines changed

src/templates/template-vision-classification/main.py

+6 -4
Original file line number | Diff line number | Diff line change
@@ -54,10 +54,12 @@ def run(local_rank: int, config: Any):
54 54
trainer.logger = evaluator.logger = logger
55 55

56 56
# set epoch for distributed sampler
57-
if idist.get_world_size() > 1 and isinstance(
58-
dataloader_train.sampler, DistributedSampler
59-
):
60-
dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)
57+
@trainer.on(Events.EPOCH_STARTED)
58+
def set_epoch():
59+
if idist.get_world_size() > 1 and isinstance(
60+
dataloader_train.sampler, DistributedSampler
61+
):
62+
dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)
61 63

62 64
# setup ignite handlers
63 65
#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#

src/templates/template-vision-dcgan/main.py

+6 -4
Original file line number | Diff line number | Diff line change
@@ -84,10 +84,12 @@ def run(local_rank: int, config: Any):
84 84
trainer.logger = evaluator.logger = logger
85 85

86 86
# set epoch for distributed sampler
87-
if idist.get_world_size() > 1 and isinstance(
88-
dataloader_train.sampler, DistributedSampler
89-
):
90-
dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)
87+
@trainer.on(Events.EPOCH_STARTED)
88+
def set_epoch():
89+
if idist.get_world_size() > 1 and isinstance(
90+
dataloader_train.sampler, DistributedSampler
91+
):
92+
dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)
91 93

92 94
# setup ignite handlers
93 95
#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#

0 commit comments

Comments
 (0)