diff --git a/templates/gan/main.py b/templates/gan/main.py
index 05eb3b74..94f430ee 100644
--- a/templates/gan/main.py
+++ b/templates/gan/main.py
@@ -21,23 +21,35 @@ from config import get_default_parser
 
-PRINT_FREQ = 100
 FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
 REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
 LOGS_FNAME = "logs.tsv"
 PLOT_FNAME = "plot.svg"
-SAMPLES_FNAME = "samples.svg"
-CKPT_PREFIX = "networks"
 
 
 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     """function to be run by idist.Parallel context manager."""
 
+    # ----------------------
+    # make a certain seed
+    # ----------------------
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------
+    if rank > 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataset, num_channels = get_datasets(config.dataset, config.data_path)
+
+    if rank == 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.batch_size,
@@ -97,7 +109,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
         lr_scheduler=lr_scheduler,
         output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
     )
-    logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)
+
+    # setup ignite logger only on rank 0
+    if rank == 0:
+        logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)
 
     # -----------------------------------
     # resume from the saved checkpoints
@@ -177,12 +192,13 @@ def create_plots(engine):
     # ------------------------------------------------------------
     # close the logger after the training completed / terminated
     # ------------------------------------------------------------
-    if isinstance(logger_handler, WandBLogger):
-        # why handle differently for wandb ?
-        # See : https://github.com/pytorch/ignite/issues/1894
-        logger_handler.finish()
-    elif logger_handler:
-        logger_handler.close()
+    if rank == 0:
+        if isinstance(logger_handler, WandBLogger):
+            # why handle differently for wandb ?
+            # See : https://github.com/pytorch/ignite/issues/1894
+            logger_handler.finish()
+        elif logger_handler:
+            logger_handler.close()
 
     # -----------------------------------------
     # where is my best and last checkpoint ?
@@ -194,7 +210,6 @@ def create_plots(engine):
 def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
-    manual_seed(config.seed)
 
     if config.output_dir:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
diff --git a/templates/image_classification/main.py b/templates/image_classification/main.py
index 0edb0762..874279ff 100644
--- a/templates/image_classification/main.py
+++ b/templates/image_classification/main.py
@@ -22,6 +22,12 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     """function to be run by idist.Parallel context manager."""
 
+    # ----------------------
+    # make a certain seed
+    # ----------------------
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------
@@ -30,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
     # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
 
+    if rank > 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataset, eval_dataset = get_datasets(path=config.data_path)
+
+    if rank == 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.train_batch_size,
@@ -110,7 +125,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
         lr_scheduler=lr_scheduler,
         output_names=None,
     )
-    logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+
+    # setup ignite logger only on rank 0
+    if rank == 0:
+        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
 
     # -----------------------------------
     # resume from the saved checkpoints
@@ -183,12 +201,13 @@ def _():
     # ------------------------------------------------------------
     # close the logger after the training completed / terminated
     # ------------------------------------------------------------
-    if isinstance(logger_handler, WandBLogger):
-        # why handle differently for wandb ?
-        # See : https://github.com/pytorch/ignite/issues/1894
-        logger_handler.finish()
-    elif logger_handler:
-        logger_handler.close()
+    if rank == 0:
+        if isinstance(logger_handler, WandBLogger):
+            # why handle differently for wandb ?
+            # See : https://github.com/pytorch/ignite/issues/1894
+            logger_handler.finish()
+        elif logger_handler:
+            logger_handler.close()
 
     # -----------------------------------------
     # where is my best and last checkpoint ?
@@ -200,7 +219,6 @@ def _():
 def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
-    manual_seed(config.seed)
 
     if config.output_dir:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
diff --git a/templates/single/main.py b/templates/single/main.py
index effc8341..a9a3e082 100644
--- a/templates/single/main.py
+++ b/templates/single/main.py
@@ -20,6 +20,12 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     """function to be run by idist.Parallel context manager."""
 
+    # ----------------------
+    # make a certain seed
+    # ----------------------
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------
@@ -28,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
     # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
 
+    if rank > 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataset = ...
     eval_dataset = ...
+
+    if rank == 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
     eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)
@@ -86,7 +101,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
         lr_scheduler=lr_scheduler,
         output_names=None,
     )
-    logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+
+    # setup ignite logger only on rank 0
+    if rank == 0:
+        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
 
     # -----------------------------------
     # resume from the saved checkpoints
@@ -160,12 +178,13 @@ def _():
     # ------------------------------------------------------------
     # close the logger after the training completed / terminated
     # ------------------------------------------------------------
-    if isinstance(logger_handler, WandBLogger):
-        # why handle differently for wandb ?
-        # See : https://github.com/pytorch/ignite/issues/1894
-        logger_handler.finish()
-    elif logger_handler:
-        logger_handler.close()
+    if rank == 0:
+        if isinstance(logger_handler, WandBLogger):
+            # why handle differently for wandb ?
+            # See : https://github.com/pytorch/ignite/issues/1894
+            logger_handler.finish()
+        elif logger_handler:
+            logger_handler.close()
 
     # -----------------------------------------
     # where is my best and last checkpoint ?
@@ -177,7 +196,6 @@ def _():
 def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
-    manual_seed(config.seed)
 
     if config.output_dir:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
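
For reference, the pattern the diff applies to all three templates can be read in isolation. The sketch below is illustrative only and is not part of the templates: the function name, the MNIST dataset, and the "./data" path are placeholders, while idist.get_rank(), idist.barrier(), manual_seed(), and idist.auto_dataloader() are the same pytorch-ignite calls used above. Each rank offsets the base seed by its rank, non-zero ranks block on a barrier while rank 0 downloads the data, and rank 0 then reaches the same barrier to release them.

    import ignite.distributed as idist
    from ignite.utils import manual_seed
    from torchvision import datasets, transforms


    def download_then_load(seed: int = 42, data_path: str = "./data"):
        rank = idist.get_rank()

        # offset the base seed by rank so each process draws a different random stream
        manual_seed(seed + rank)

        if rank > 0:
            # non-zero ranks wait here until rank 0 has finished downloading
            idist.barrier()

        dataset = datasets.MNIST(data_path, download=True, transform=transforms.ToTensor())

        if rank == 0:
            # rank 0 reaches the barrier last, releasing the waiting ranks
            idist.barrier()

        # auto_dataloader adds a DistributedSampler and scales the batch size when run distributed
        return idist.auto_dataloader(dataset, batch_size=64, shuffle=True)

The same rank == 0 guard is what restricts logger creation and logger shutdown to the main process in the hunks above, so only one process writes experiment logs.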