fix: make seeding respect distributed settings #60

Merged: 3 commits, Apr 10, 2021
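For context, the pattern this PR applies to all three templates is sketched below: the seed is set inside the spawned run function and offset by the process rank, the dataset download is guarded by a pair of barriers so only one process fetches the data, and the dataloader is built with idist.auto_dataloader. This is a simplified illustration rather than the exact template code; load_datasets is a stand-in for the templates' get_datasets helper, and the config is faked with a Namespace.

# Simplified sketch of the distributed-aware pattern introduced by this PR.
# `load_datasets` is a stand-in for the templates' `get_datasets` helper.
from argparse import Namespace

import torch
from torch.utils.data import TensorDataset

import ignite.distributed as idist
from ignite.utils import manual_seed


def load_datasets(path: str):
    # Pretend this downloads data to `path`; here we just build a toy dataset.
    data, targets = torch.randn(64, 3), torch.randint(0, 2, (64,))
    return TensorDataset(data, targets)


def run(local_rank: int, config) -> None:
    # Seed inside the spawned worker, offset by the global rank, so each
    # process gets a different but reproducible random stream.
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # Only rank 0 downloads the dataset: the other ranks wait at the first
    # barrier, and rank 0 releases them once the data is available.
    if rank > 0:
        idist.barrier()
    train_dataset = load_datasets(config.data_path)
    if rank == 0:
        idist.barrier()

    train_dataloader = idist.auto_dataloader(train_dataset, batch_size=config.batch_size)
    print(f"rank={rank}: first batch shape", next(iter(train_dataloader))[0].shape)


if __name__ == "__main__":
    # Serial (non-distributed) call; in the templates this is driven by idist.Parallel.
    run(0, Namespace(seed=666, data_path="./data", batch_size=8))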
37 changes: 26 additions & 11 deletions templates/gan/main.py
@@ -21,23 +21,35 @@
from config import get_default_parser


PRINT_FREQ = 100
FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
LOGS_FNAME = "logs.tsv"
PLOT_FNAME = "plot.svg"
SAMPLES_FNAME = "samples.svg"
CKPT_PREFIX = "networks"


def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
"""function to be run by idist.Parallel context manager."""

# ----------------------
# set a per-rank deterministic seed
# ----------------------
rank = idist.get_rank()
manual_seed(config.seed + rank)

# -----------------------------
# datasets and dataloaders
# -----------------------------

if rank > 0:
# Ensure that only rank 0 downloads the dataset; other ranks wait here
idist.barrier()
Comment on lines +43 to +45

Member:

Sorry for the delay with the review, @ydcjeff. Maybe we could hide these barriers inside get_datasets. I also realized that it should be local_rank == 0 here rather than rank == 0, since in the multi-node case the application won't download the data on the other nodes.

Member Author:

OK! Then the Ignite examples need a fix too.

Member:

I see, that could be the case. Thanks for pointing it out! Can you please open an issue so that someone can fix it?

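To illustrate the reviewer's suggestion above (this is not part of the PR): the barrier pair could be moved inside the dataset helper and keyed on local_rank, so that on every node exactly one process, the one with local_rank == 0, performs the download. A rough sketch, using a torchvision MNIST dataset as a stand-in for the templates' actual datasets:

# Sketch of the suggested refactor (not part of this PR): hide the barriers
# inside the dataset helper and gate the download on local_rank, so each node
# downloads its own copy of the data exactly once.
import ignite.distributed as idist
from torchvision import datasets, transforms


def get_datasets(path: str):
    local_rank = idist.get_local_rank()

    if local_rank > 0:
        # Non-zero local ranks wait here until local_rank 0 on this node
        # has finished downloading.
        idist.barrier()

    dataset = datasets.MNIST(
        path, train=True, download=True, transform=transforms.ToTensor()
    )

    if local_rank == 0:
        # Download finished; release the waiting processes.
        idist.barrier()

    return dataset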

train_dataset, num_channels = get_datasets(config.dataset, config.data_path)

if rank == 0:
# Rank 0 has downloaded the dataset; release the waiting ranks
idist.barrier()

train_dataloader = idist.auto_dataloader(
train_dataset,
batch_size=config.batch_size,
@@ -97,7 +109,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
lr_scheduler=lr_scheduler,
output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
)
logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)

# setup ignite logger only on rank 0
if rank == 0:
logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)

# -----------------------------------
# resume from the saved checkpoints
@@ -177,12 +192,13 @@ def create_plots(engine):
# close the logger after the training completed / terminated
# ------------------------------------------------------------

if isinstance(logger_handler, WandBLogger):
# why handle differently for wandb ?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
if rank == 0:
if isinstance(logger_handler, WandBLogger):
# why handle differently for wandb ?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()

# -----------------------------------------
# where is my best and last checkpoint ?
@@ -194,7 +210,6 @@ def create_plots(engine):
def main():
parser = ArgumentParser(parents=[get_default_parser()])
config = parser.parse_args()
manual_seed(config.seed)

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
34 changes: 26 additions & 8 deletions templates/image_classification/main.py
@@ -22,6 +22,12 @@
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
"""function to be run by idist.Parallel context manager."""

# ----------------------
# set a per-rank deterministic seed
# ----------------------
rank = idist.get_rank()
manual_seed(config.seed + rank)

# -----------------------------
# datasets and dataloaders
# -----------------------------
@@ -30,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
# See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

if rank > 0:
# Ensure that only rank 0 downloads the dataset; other ranks wait here
idist.barrier()

train_dataset, eval_dataset = get_datasets(path=config.data_path)

if rank == 0:
# Rank 0 has downloaded the dataset; release the waiting ranks
idist.barrier()

train_dataloader = idist.auto_dataloader(
train_dataset,
batch_size=config.train_batch_size,
@@ -110,7 +125,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
lr_scheduler=lr_scheduler,
output_names=None,
)
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# setup ignite logger only on rank 0
if rank == 0:
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# -----------------------------------
# resume from the saved checkpoints
@@ -183,12 +201,13 @@ def _():
# close the logger after the training completed / terminated
# ------------------------------------------------------------

if isinstance(logger_handler, WandBLogger):
# why handle differently for wandb ?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
if rank == 0:
if isinstance(logger_handler, WandBLogger):
# why handle differently for wandb ?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
Comment on lines +204 to +210

Member:

Maybe we can say that the code generator targets the ignite nightly version, and that this will also be in the stable version?

Member Author:

We are currently testing with ignite v0.4.4; we will update that once there is a new release.


# -----------------------------------------
# where is my best and last checkpoint ?
@@ -200,7 +219,6 @@ def _():
def main():
parser = ArgumentParser(parents=[get_default_parser()])
config = parser.parse_args()
manual_seed(config.seed)

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
34 changes: 26 additions & 8 deletions templates/single/main.py
@@ -20,6 +20,12 @@
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
"""function to be run by idist.Parallel context manager."""

# ----------------------
# set a per-rank deterministic seed
# ----------------------
rank = idist.get_rank()
manual_seed(config.seed + rank)

# -----------------------------
# datasets and dataloaders
# -----------------------------
@@ -28,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
# TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
# See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

if rank > 0:
# Ensure that only rank 0 downloads the dataset; other ranks wait here
idist.barrier()

train_dataset = ...
eval_dataset = ...

if rank == 0:
# Rank 0 has downloaded the dataset; release the waiting ranks
idist.barrier()

train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

@@ -86,7 +101,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
lr_scheduler=lr_scheduler,
output_names=None,
)
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# setup ignite logger only on rank 0
if rank == 0:
logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

# -----------------------------------
# resume from the saved checkpoints
@@ -160,12 +178,13 @@ def _():
# close the logger after the training completed / terminated
# ------------------------------------------------------------

if isinstance(logger_handler, WandBLogger):
# why handle differently for wandb ?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()
if rank == 0:
if isinstance(logger_handler, WandBLogger):
# why handle differently for wandb ?
# See : https://github.com/pytorch/ignite/issues/1894
logger_handler.finish()
elif logger_handler:
logger_handler.close()

# -----------------------------------------
# where is my best and last checkpoint ?
@@ -177,7 +196,6 @@ def _():
def main():
parser = ArgumentParser(parents=[get_default_parser()])
config = parser.parse_args()
manual_seed(config.seed)

if config.output_dir:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
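As the run docstrings note, these functions are executed through ignite's idist.Parallel context manager, which is what makes idist.get_rank(), the barriers, and the per-rank seeding meaningful across processes. A minimal launch sketch follows; the backend and process count are illustrative choices, not taken from the templates.

# Minimal sketch of how a `run(local_rank, config)` function like the ones in
# this PR is driven by idist.Parallel. Backend and nproc_per_node are
# illustrative.
from argparse import Namespace

import ignite.distributed as idist


def run(local_rank: int, config) -> None:
    print(
        f"local_rank={local_rank}, "
        f"rank={idist.get_rank()}, "
        f"world_size={idist.get_world_size()}"
    )


if __name__ == "__main__":
    config = Namespace(seed=666)
    # backend=None would run serially; "gloo" spawns nproc_per_node CPU workers.
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(run, config)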