Skip to content

Commit 311a1de

Browse files
authored
Merge branch 'main' into refactor-zipRes
2 parents 012997d + 78444cd commit 311a1de

File tree

13 files changed

+52
-22
lines changed

13 files changed

+52
-22
lines changed

scripts/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ pytest
88
transformers
99
datasets
1010
tensorboard
11+
omegaconf

src/templates/template-common/config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ num_workers: 4
66
max_epochs: 20
77
use_amp: false
88
debug: false
9+
train_epoch_length: null
10+
eval_epoch_length: null
911

1012
#::: if (it.dist === 'spawn') { :::#
1113
# distributed spawn

src/templates/template-common/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ torch>=1.10.2
22
torchvision>=0.11.3
33
pytorch-ignite>=0.4.8
44
pyyaml
5+
omegaconf
56

67
#::: if (['neptune', 'polyaxon'].includes(it.logger)) { :::#
78

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
def test_save_config():
2+
with open("./config.yaml", "r") as f:
3+
config = OmegaConf.load(f)
4+
5+
save_config(config, "./")
6+
7+
with open("./config-lock.yaml", "r") as f:
8+
test_config = OmegaConf.load(f)
9+
10+
assert config == test_config

src/templates/template-common/utils.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import ignite.distributed as idist
99
import torch
10-
import yaml
1110
from ignite.contrib.engines import common
1211
from ignite.engine import Engine
1312

@@ -35,6 +34,7 @@
3534

3635
#::: } :::#
3736
from ignite.utils import setup_logger
37+
from omegaconf import DictConfig, OmegaConf
3838

3939

4040
def get_default_parser():
@@ -57,17 +57,11 @@ def setup_config(parser=None):
5757
args = parser.parse_args()
5858
config_path = args.config
5959

60-
with open(config_path, "r") as f:
61-
config = yaml.safe_load(f.read())
60+
config = OmegaConf.load(config_path)
6261

63-
optional_attributes = ["train_epoch_length", "eval_epoch_length"]
64-
for attr in optional_attributes:
65-
config[attr] = config.get(attr, None)
62+
config.backend = args.backend
6663

67-
for k, v in config.items():
68-
setattr(args, k, v)
69-
70-
return args
64+
return DictConfig(config)
7165

7266

7367
def log_metrics(engine: Engine, tag: str) -> None:
@@ -138,6 +132,12 @@ def setup_output_dir(config: Any, rank: int) -> Path:
138132
return Path(idist.broadcast(config.output_dir, src=0))
139133

140134

135+
def save_config(config, output_dir):
136+
"""Save configuration to config-lock.yaml for result reproducibility."""
137+
with open(f"{output_dir}/config-lock.yaml", "w") as f:
138+
OmegaConf.save(config, f)
139+
140+
141141
def setup_logging(config: Any) -> Logger:
142142
"""Setup logger with `ignite.utils.setup_logger()`.
143143

src/templates/template-text-classification/main.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
from pprint import pformat
3-
from shutil import copy
43
from typing import Any
54

65
import ignite.distributed as idist
@@ -25,7 +24,7 @@ def run(local_rank: int, config: Any):
2524
# create output folder and copy config file to output dir
2625
config.output_dir = setup_output_dir(config, rank)
2726
if rank == 0:
28-
copy(config.config, f"{config.output_dir}/config-lock.yaml")
27+
save_config(config, config.output_dir)
2928

3029
# donwload datasets and create dataloaders
3130
dataloader_train, dataloader_eval = setup_data(config)
@@ -69,7 +68,7 @@ def run(local_rank: int, config: Any):
6968
# setup engines logger with python logging
7069
# print training configurations
7170
logger = setup_logging(config)
72-
logger.info("Configuration: \n%s", pformat(vars(config)))
71+
logger.info("Configuration: \n%s", pformat(config))
7372
trainer.logger = evaluator.logger = logger
7473

7574
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

src/templates/template-text-classification/test_all.py

+5
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import ignite.distributed as idist
66
import pytest
77
from data import setup_data
8+
from omegaconf import OmegaConf
89
from torch import nn, optim
910
from torch.functional import Tensor
1011
from torch.utils.data import DataLoader
12+
from utils import save_config
1113

1214

1315
def set_up():
@@ -45,3 +47,6 @@ def test_setup_data():
4547
assert isinstance(eval_batch["attention_mask"], Tensor)
4648
assert isinstance(eval_batch["token_type_ids"], Tensor)
4749
assert isinstance(eval_batch["label"], Tensor)
50+
51+
52+
#::= from_template_common ::#

src/templates/template-vision-classification/main.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from pprint import pformat
2-
from shutil import copy
32
from typing import Any
43

54
import ignite.distributed as idist
@@ -22,7 +21,7 @@ def run(local_rank: int, config: Any):
2221
# create output folder and copy config file to output dir
2322
config.output_dir = setup_output_dir(config, rank)
2423
if rank == 0:
25-
copy(config.config, f"{config.output_dir}/config-lock.yaml")
24+
save_config(config, config.output_dir)
2625

2726
# donwload datasets and create dataloaders
2827
dataloader_train, dataloader_eval = setup_data(config)
@@ -59,7 +58,7 @@ def run(local_rank: int, config: Any):
5958
# setup engines logger with python logging
6059
# print training configurations
6160
logger = setup_logging(config)
62-
logger.info("Configuration: \n%s", pformat(vars(config)))
61+
logger.info("Configuration: \n%s", pformat(config))
6362
trainer.logger = evaluator.logger = logger
6463

6564
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

src/templates/template-vision-classification/test_all.py

+5
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import pytest
77
import torch
88
from data import setup_data
9+
from omegaconf import OmegaConf
910
from torch import nn, optim, Tensor
1011
from torch.utils.data.dataloader import DataLoader
1112
from trainers import setup_evaluator
13+
from utils import save_config
1214

1315

1416
def set_up():
@@ -48,3 +50,6 @@ def test_setup_evaluator():
4850
evaluator = setup_evaluator(config, model, device)
4951
evaluator.run([batch, batch])
5052
assert isinstance(evaluator.state.output, tuple)
53+
54+
55+
#::= from_template_common ::#

src/templates/template-vision-dcgan/main.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from pprint import pformat
2-
from shutil import copy
32
from typing import Any
43

54
import ignite.distributed as idist
@@ -25,7 +24,7 @@ def run(local_rank: int, config: Any):
2524
# create output folder and copy config file to output dir
2625
config.output_dir = setup_output_dir(config, rank)
2726
if rank == 0:
28-
copy(config.config, f"{config.output_dir}/config-lock.yaml")
27+
save_config(config, config.output_dir)
2928

3029
# donwload datasets and create dataloaders
3130
dataloader_train, dataloader_eval, num_channels = setup_data(config)
@@ -74,7 +73,7 @@ def run(local_rank: int, config: Any):
7473
# setup engines logger with python logging
7574
# print training configurations
7675
logger = setup_logging(config)
77-
logger.info("Configuration: \n%s", pformat(vars(config)))
76+
logger.info("Configuration: \n%s", pformat(config))
7877
trainer.logger = evaluator.logger = logger
7978

8079
#::: if (it.save_training || it.save_evaluation) { :::#

src/templates/template-vision-dcgan/test_all.py

+5
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import torch
88
from data import setup_data
99
from models import Discriminator, Generator
10+
from omegaconf import OmegaConf
1011
from torch import nn, optim, Tensor
1112
from torch.utils.data.dataloader import DataLoader
1213
from trainers import setup_trainer
14+
from utils import save_config
1315

1416

1517
def set_up():
@@ -62,3 +64,6 @@ def test_setup_trainer():
6264
trainer = setup_trainer(config, model, model, optimizer, optimizer, loss_fn, device, None)
6365
trainer.run([batch, batch])
6466
assert isinstance(trainer.state.output, dict)
67+
68+
69+
#::= from_template_common ::#

src/templates/template-vision-segmentation/main.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import partial
22
from pprint import pformat
3-
from shutil import copy
43
from typing import Any, cast
54

65
import ignite.distributed as idist
@@ -30,7 +29,7 @@ def run(local_rank: int, config: Any):
3029
# create output folder and copy config file to output dir
3130
config.output_dir = setup_output_dir(config, rank)
3231
if rank == 0:
33-
copy(config.config, f"{config.output_dir}/config-lock.yaml")
32+
save_config(config, config.output_dir)
3433

3534
# donwload datasets and create dataloaders
3635
dataloader_train, dataloader_eval = setup_data(config)
@@ -72,7 +71,7 @@ def run(local_rank: int, config: Any):
7271
# setup engines logger with python logging
7372
# print training configurations
7473
logger = setup_logging(config)
75-
logger.info("Configuration: \n%s", pformat(vars(config)))
74+
logger.info("Configuration: \n%s", pformat(config))
7675
trainer.logger = evaluator.logger = logger
7776

7877
if isinstance(lr_scheduler, PyTorchLRScheduler):

src/templates/template-vision-segmentation/test_all.py

+5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
import pytest
55
from data import setup_data
6+
from omegaconf import OmegaConf
67
from torch import Tensor
78
from torch.utils.data.dataloader import DataLoader
9+
from utils import save_config
810

911

1012
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
@@ -26,3 +28,6 @@ def test_setup_data():
2628
assert isinstance(eval_batch["mask"], Tensor)
2729
assert eval_batch["image"].ndim == 4
2830
assert eval_batch["mask"].ndim == 3
31+
32+
33+
#::= from_template_common ::#

0 commit comments

Comments
 (0)