From 3f9bdf792ae516e46bec5edc6fc965b42ca4f268 Mon Sep 17 00:00:00 2001 From: ydcjeff Date: Sat, 10 Apr 2021 13:46:11 +0630 Subject: [PATCH 1/3] fix: make a seed in run --- templates/gan/main.py | 7 ++++++- templates/image_classification/main.py | 7 ++++++- templates/single/main.py | 7 ++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/templates/gan/main.py b/templates/gan/main.py index 05eb3b74..25b786d3 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -33,6 +33,12 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): """function to be run by idist.Parallel context manager.""" + # ---------------------- + # make a certain seed + # ---------------------- + rank = idist.get_rank() + manual_seed(config.seed + rank) + # ----------------------------- # datasets and dataloaders # ----------------------------- @@ -194,7 +200,6 @@ def create_plots(engine): def main(): parser = ArgumentParser(parents=[get_default_parser()]) config = parser.parse_args() - manual_seed(config.seed) if config.output_dir: now = datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py index 0edb0762..7efb0e71 100644 --- a/templates/image_classification/main.py +++ b/templates/image_classification/main.py @@ -22,6 +22,12 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): """function to be run by idist.Parallel context manager.""" + # ---------------------- + # make a certain seed + # ---------------------- + rank = idist.get_rank() + manual_seed(config.seed + rank) + # ----------------------------- # datasets and dataloaders # ----------------------------- @@ -200,7 +206,6 @@ def _(): def main(): parser = ArgumentParser(parents=[get_default_parser()]) config = parser.parse_args() - manual_seed(config.seed) if config.output_dir: now = datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/templates/single/main.py b/templates/single/main.py index effc8341..130dcdfb 100644 --- a/templates/single/main.py +++ b/templates/single/main.py @@ -20,6 +20,12 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): """function to be run by idist.Parallel context manager.""" + # ---------------------- + # make a certain seed + # ---------------------- + rank = idist.get_rank() + manual_seed(config.seed + rank) + # ----------------------------- # datasets and dataloaders # ----------------------------- @@ -177,7 +183,6 @@ def _(): def main(): parser = ArgumentParser(parents=[get_default_parser()]) config = parser.parse_args() - manual_seed(config.seed) if config.output_dir: now = datetime.now().strftime("%Y%m%d-%H%M%S") From fb9e972a8717d3ec87e5bf7a58be9174d1d53cfc Mon Sep 17 00:00:00 2001 From: Jeff Yang Date: Sat, 10 Apr 2021 16:47:14 +0630 Subject: [PATCH 2/3] fix: download datasets only on rank 0 (#62) --- templates/gan/main.py | 9 +++++++++ templates/image_classification/main.py | 9 +++++++++ templates/single/main.py | 9 +++++++++ 3 files changed, 27 insertions(+) diff --git a/templates/gan/main.py b/templates/gan/main.py index 25b786d3..e5b36c40 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -43,7 +43,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # datasets and dataloaders # ----------------------------- + if rank > 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataset, num_channels = get_datasets(config.dataset, config.data_path) + + if rank == 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataloader = idist.auto_dataloader( train_dataset, batch_size=config.batch_size, diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py index 7efb0e71..37a05341 100644 --- a/templates/image_classification/main.py +++ b/templates/image_classification/main.py @@ -36,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader + if rank > 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataset, eval_dataset = get_datasets(path=config.data_path) + + if rank == 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataloader = idist.auto_dataloader( train_dataset, batch_size=config.train_batch_size, diff --git a/templates/single/main.py b/templates/single/main.py index 130dcdfb..8d5915e8 100644 --- a/templates/single/main.py +++ b/templates/single/main.py @@ -34,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader + if rank > 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataset = ... eval_dataset = ... + + if rank == 0: + # Ensure that only rank 0 download the dataset + idist.barrier() + train_dataloader = idist.auto_dataloader(train_dataset, **kwargs) eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs) From e1a21f65ae4fc2fe646880a8cd671e5dd8451d8d Mon Sep 17 00:00:00 2001 From: Jeff Yang Date: Sat, 10 Apr 2021 20:51:42 +0630 Subject: [PATCH 3/3] fix: create ignite loggers only on rank 0 (#64) --- templates/gan/main.py | 21 +++++++++++---------- templates/image_classification/main.py | 18 +++++++++++------- templates/single/main.py | 18 +++++++++++------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/templates/gan/main.py b/templates/gan/main.py index e5b36c40..94f430ee 100644 --- a/templates/gan/main.py +++ b/templates/gan/main.py @@ -21,13 +21,10 @@ from config import get_default_parser -PRINT_FREQ = 100 FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png" REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png" LOGS_FNAME = "logs.tsv" PLOT_FNAME = "plot.svg" -SAMPLES_FNAME = "samples.svg" -CKPT_PREFIX = "networks" def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): @@ -112,7 +109,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): lr_scheduler=lr_scheduler, output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"], ) - logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers) + + # setup ignite logger only on rank 0 + if rank == 0: + logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers) # ----------------------------------- # resume from the saved checkpoints @@ -192,12 +192,13 @@ def create_plots(engine): # close the logger after the training completed / terminated # ------------------------------------------------------------ - if isinstance(logger_handler, WandBLogger): - # why handle differently for wandb ? - # See : https://github.com/pytorch/ignite/issues/1894 - logger_handler.finish() - elif logger_handler: - logger_handler.close() + if rank == 0: + if isinstance(logger_handler, WandBLogger): + # why handle differently for wandb ? + # See : https://github.com/pytorch/ignite/issues/1894 + logger_handler.finish() + elif logger_handler: + logger_handler.close() # ----------------------------------------- # where is my best and last checkpoint ? diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py index 37a05341..874279ff 100644 --- a/templates/image_classification/main.py +++ b/templates/image_classification/main.py @@ -125,7 +125,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): lr_scheduler=lr_scheduler, output_names=None, ) - logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) + + # setup ignite logger only on rank 0 + if rank == 0: + logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) # ----------------------------------- # resume from the saved checkpoints @@ -198,12 +201,13 @@ def _(): # close the logger after the training completed / terminated # ------------------------------------------------------------ - if isinstance(logger_handler, WandBLogger): - # why handle differently for wandb ? - # See : https://github.com/pytorch/ignite/issues/1894 - logger_handler.finish() - elif logger_handler: - logger_handler.close() + if rank == 0: + if isinstance(logger_handler, WandBLogger): + # why handle differently for wandb ? + # See : https://github.com/pytorch/ignite/issues/1894 + logger_handler.finish() + elif logger_handler: + logger_handler.close() # ----------------------------------------- # where is my best and last checkpoint ? diff --git a/templates/single/main.py b/templates/single/main.py index 8d5915e8..a9a3e082 100644 --- a/templates/single/main.py +++ b/templates/single/main.py @@ -101,7 +101,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any): lr_scheduler=lr_scheduler, output_names=None, ) - logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) + + # setup ignite logger only on rank 0 + if rank == 0: + logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer) # ----------------------------------- # resume from the saved checkpoints @@ -175,12 +178,13 @@ def _(): # close the logger after the training completed / terminated # ------------------------------------------------------------ - if isinstance(logger_handler, WandBLogger): - # why handle differently for wandb ? - # See : https://github.com/pytorch/ignite/issues/1894 - logger_handler.finish() - elif logger_handler: - logger_handler.close() + if rank == 0: + if isinstance(logger_handler, WandBLogger): + # why handle differently for wandb ? + # See : https://github.com/pytorch/ignite/issues/1894 + logger_handler.finish() + elif logger_handler: + logger_handler.close() # ----------------------------------------- # where is my best and last checkpoint ?