Commit b17397b

Author: Jeff Yang
Parent: 920ac61

fix: make a seed respect to distributed settings (#60)

* fix: make a seed in run
* fix: download datasets only on rank 0 (#62)
* fix: create ignite loggers only on rank 0 (#64)
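
All three templates receive the same rank-aware treatment: seed each worker from config.seed plus its rank, let only rank 0 download data, and keep the experiment logger on rank 0. A minimal sketch of the combined pattern inside run(); load_or_download_dataset and setup_experiment_logger are hypothetical stand-ins for the templates' get_datasets and get_logger helpers:

import ignite.distributed as idist
from ignite.utils import manual_seed


def run(local_rank: int, config, *args, **kwargs):
    # seed every spawned process; the rank offset gives each worker a
    # distinct but reproducible random stream
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # only rank 0 downloads; the other ranks wait at the barrier until
    # the files are on disk
    if rank > 0:
        idist.barrier()
    train_dataset = load_or_download_dataset(config.data_path)  # hypothetical helper
    if rank == 0:
        idist.barrier()

    # experiment loggers exist on rank 0 only; None elsewhere
    logger_handler = None
    if rank == 0:
        logger_handler = setup_experiment_logger(config)  # hypothetical helper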

File tree: 3 files changed, +78 −27 lines


templates/gan/main.py  (+26 −11)

@@ -21,23 +21,35 @@
 from config import get_default_parser


-PRINT_FREQ = 100
 FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
 REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
 LOGS_FNAME = "logs.tsv"
 PLOT_FNAME = "plot.svg"
-SAMPLES_FNAME = "samples.svg"
-CKPT_PREFIX = "networks"


 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     """function to be run by idist.Parallel context manager."""

+    # ----------------------
+    # make a certain seed
+    # ----------------------
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------

+    if rank > 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataset, num_channels = get_datasets(config.dataset, config.data_path)
+
+    if rank == 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.batch_size,
@@ -97,7 +109,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
         lr_scheduler=lr_scheduler,
         output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
     )
-    logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)
+
+    # setup ignite logger only on rank 0
+    if rank == 0:
+        logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizers)

     # -----------------------------------
     # resume from the saved checkpoints
@@ -177,12 +192,13 @@ def create_plots(engine):
     # close the logger after the training completed / terminated
     # ------------------------------------------------------------

-    if isinstance(logger_handler, WandBLogger):
-        # why handle differently for wandb ?
-        # See : https://github.com/pytorch/ignite/issues/1894
-        logger_handler.finish()
-    elif logger_handler:
-        logger_handler.close()
+    if rank == 0:
+        if isinstance(logger_handler, WandBLogger):
+            # why handle differently for wandb ?
+            # See : https://github.com/pytorch/ignite/issues/1894
+            logger_handler.finish()
+        elif logger_handler:
+            logger_handler.close()

     # -----------------------------------------
     # where is my best and last checkpoint ?
@@ -194,7 +210,6 @@ def create_plots(engine):
 def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
-    manual_seed(config.seed)

     if config.output_dir:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
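
The download guard works because idist.barrier() is a collective call: every rank must reach it before any rank proceeds. Ranks above 0 block first, rank 0 downloads and then joins the barrier, releasing everyone once the files are on disk. A self-contained sketch of the same rendezvous using torchvision's MNIST (the template routes this through its own get_datasets instead):

import ignite.distributed as idist
from torchvision import datasets, transforms


def mnist_rank0_download(root: str):
    rank = idist.get_rank()
    if rank > 0:
        # non-zero ranks park here until rank 0 finishes downloading
        idist.barrier()
    # on rank 0 this triggers the actual download; on the other ranks the
    # files already exist, so download=True becomes a no-op
    dataset = datasets.MNIST(root, download=True, transform=transforms.ToTensor())
    if rank == 0:
        # rank 0 joins the barrier last, releasing the waiting ranks
        idist.barrier()
    return dataset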

templates/image_classification/main.py  (+26 −8)

@@ -22,6 +22,12 @@
 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     """function to be run by idist.Parallel context manager."""

+    # ----------------------
+    # make a certain seed
+    # ----------------------
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------
@@ -30,7 +36,16 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
     # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

+    if rank > 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataset, eval_dataset = get_datasets(path=config.data_path)
+
+    if rank == 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.train_batch_size,
@@ -110,7 +125,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
         lr_scheduler=lr_scheduler,
         output_names=None,
     )
-    logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+
+    # setup ignite logger only on rank 0
+    if rank == 0:
+        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

     # -----------------------------------
     # resume from the saved checkpoints
@@ -183,12 +201,13 @@ def _():
     # close the logger after the training completed / terminated
     # ------------------------------------------------------------

-    if isinstance(logger_handler, WandBLogger):
-        # why handle differently for wandb ?
-        # See : https://github.com/pytorch/ignite/issues/1894
-        logger_handler.finish()
-    elif logger_handler:
-        logger_handler.close()
+    if rank == 0:
+        if isinstance(logger_handler, WandBLogger):
+            # why handle differently for wandb ?
+            # See : https://github.com/pytorch/ignite/issues/1894
+            logger_handler.finish()
+        elif logger_handler:
+            logger_handler.close()

     # -----------------------------------------
     # where is my best and last checkpoint ?
@@ -200,7 +219,6 @@ def _():
 def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
-    manual_seed(config.seed)

     if config.output_dir:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
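
Note that with get_logger called only on rank 0, logger_handler is never assigned on the other ranks, which is why the finish()/close() block at the end of training is wrapped in the same if rank == 0 guard. A slightly more defensive variant (a sketch reusing the template's names, not a drop-in patch) initializes the handler to None everywhere so the teardown needs no guard:

import ignite.distributed as idist
from ignite.contrib.handlers import WandBLogger

rank = idist.get_rank()

# None on non-zero ranks, so the teardown below is safe on every rank
logger_handler = None
if rank == 0:
    logger_handler = get_logger(config=config, train_engine=train_engine, optimizers=optimizer)

# ... training runs here ...

if isinstance(logger_handler, WandBLogger):
    # wandb needs finish() rather than close(),
    # see https://github.com/pytorch/ignite/issues/1894
    logger_handler.finish()
elif logger_handler:
    logger_handler.close()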

templates/single/main.py  (+26 −8)

@@ -20,6 +20,12 @@
 def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     """function to be run by idist.Parallel context manager."""

+    # ----------------------
+    # make a certain seed
+    # ----------------------
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------
@@ -28,8 +34,17 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
     # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader

+    if rank > 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataset = ...
     eval_dataset = ...
+
+    if rank == 0:
+        # Ensure that only rank 0 download the dataset
+        idist.barrier()
+
     train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
     eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

@@ -86,7 +101,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
         lr_scheduler=lr_scheduler,
         output_names=None,
     )
-    logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)
+
+    # setup ignite logger only on rank 0
+    if rank == 0:
+        logger_handler = get_logger(config=config, train_engine=train_engine, eval_engine=eval_engine, optimizers=optimizer)

     # -----------------------------------
     # resume from the saved checkpoints
@@ -160,12 +178,13 @@ def _():
     # close the logger after the training completed / terminated
     # ------------------------------------------------------------

-    if isinstance(logger_handler, WandBLogger):
-        # why handle differently for wandb ?
-        # See : https://github.com/pytorch/ignite/issues/1894
-        logger_handler.finish()
-    elif logger_handler:
-        logger_handler.close()
+    if rank == 0:
+        if isinstance(logger_handler, WandBLogger):
+            # why handle differently for wandb ?
+            # See : https://github.com/pytorch/ignite/issues/1894
+            logger_handler.finish()
+        elif logger_handler:
+            logger_handler.close()

     # -----------------------------------------
     # where is my best and last checkpoint ?
@@ -177,7 +196,6 @@ def _():
 def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
-    manual_seed(config.seed)

     if config.output_dir:
         now = datetime.now().strftime("%Y%m%d-%H%M%S")
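
Moving manual_seed out of main() matters because of how these templates launch: main() runs once in the launching process, while run() executes in every spawned worker, so only a seed set inside run() reaches all ranks. A hedged launch sketch (backend and process count are illustrative):

import ignite.distributed as idist


def main():
    config = ...  # parsed from get_default_parser(), as in the templates
    # backend=None would run in the current process; "gloo"/"nccl" plus
    # nproc_per_node spawns one process per rank, each calling run()
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(run, config)


if __name__ == "__main__":
    main()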
