Skip to content

Commit 827caf2

Browse files
Restructured config (#243)
* Restructured config - Restructured config so that arguments are defined in cofing.yaml * Fix main.py * MOdified tests according to new config restructuring * Updating remaining templates with restructured config * Update according to original config args * Configs for all the templates - Created new yaml files for testing the code - These are the test args that will be run when we run the tests * Modified tests according to new config structure * Fix typo * Correct backend argument to be passed in command line * Pass backend argument as a command line argument * Modifying the config structure in template-common --------- Co-authored-by: vfdev <[email protected]>
1 parent 81e88a2 commit 827caf2

27 files changed

+428
-81
lines changed

scripts/run_tests.sh

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,7 @@ run_simple() {
1818
for dir in $(find ./dist-tests/$1-simple -type d)
1919
do
2020
cd $dir
21-
python main.py --data_path ~/data \
22-
--train_batch_size 2 \
23-
--eval_batch_size 2 \
24-
--num_workers 2 \
25-
--max_epochs 2 \
26-
--train_epoch_length 4 \
27-
--eval_epoch_length 4
21+
python main.py ../../src/tests/ci-configs/$1-simple.yaml
2822
cd $CWD
2923
done
3024
}
@@ -34,13 +28,7 @@ run_all() {
3428
do
3529
cd $dir
3630
pytest -vra --color=yes --tb=short test_*.py
37-
python main.py --data_path ~/data \
38-
--train_batch_size 2 \
39-
--eval_batch_size 2 \
40-
--num_workers 2 \
41-
--max_epochs 2 \
42-
--train_epoch_length 4 \
43-
--eval_epoch_length 4
31+
python main.py ../../src/tests/ci-configs/$1-all.yaml
4432
cd $CWD
4533
done
4634
}
@@ -49,15 +37,7 @@ run_launch() {
4937
for dir in $(find ./dist-tests/$1-launch -type d)
5038
do
5139
cd $dir
52-
torchrun \
53-
--nproc_per_node 2 \
54-
main.py --backend gloo --data_path ~/data \
55-
--train_batch_size 2 \
56-
--eval_batch_size 2 \
57-
--num_workers 1 \
58-
--max_epochs 2 \
59-
--train_epoch_length 4 \
60-
--eval_epoch_length 4
40+
torchrun --nproc_per_node 2 main.py ../../src/tests/ci-configs/$1-launch.yaml --backend gloo
6141
cd $CWD
6242
done
6343
}
@@ -66,14 +46,7 @@ run_spawn() {
6646
for dir in $(find ./dist-tests/$1-spawn -type d)
6747
do
6848
cd $dir
69-
python main.py --data_path ~/data \
70-
--nproc_per_node 2 --backend gloo \
71-
--train_batch_size 4 \
72-
--eval_batch_size 4 \
73-
--num_workers 1 \
74-
--max_epochs 2 \
75-
--train_epoch_length 4 \
76-
--eval_epoch_length 4
49+
python main.py ../../src/tests/ci-configs/$1-spawn.yaml --backend gloo
7750
cd $CWD
7851
done
7952
}

src/templates/template-common/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
# main entrypoint
2525
def main():
26-
config = setup_parser().parse_args()
26+
config = setup_config()
2727
#::: if (it.dist === 'spawn') { :::#
2828
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
2929
kwargs = {

src/templates/template-common/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,33 @@
1818
from ignite.utils import setup_logger
1919

2020

21-
def setup_parser():
22-
with open("config.yaml", "r") as f:
21+
def get_default_parser():
22+
parser = ArgumentParser()
23+
parser.add_argument("config", type=Path, help="Config file path")
24+
parser.add_argument(
25+
"--backend",
26+
default=None,
27+
choices=["nccl", "gloo"],
28+
type=str,
29+
help="DDP backend",
30+
)
31+
return parser
32+
33+
34+
def setup_config(parser=None):
35+
if parser is None:
36+
parser = get_default_parser()
37+
38+
args = parser.parse_args()
39+
config_path = args.config
40+
41+
with open(config_path, "r") as f:
2342
config = yaml.safe_load(f.read())
2443

25-
parser = ArgumentParser()
26-
parser.add_argument("--backend", default=None, type=str)
2744
for k, v in config.items():
28-
if isinstance(v, bool):
29-
parser.add_argument(f"--{k}", action="store_true")
30-
else:
31-
parser.add_argument(f"--{k}", default=v, type=type(v))
45+
setattr(args, k, v)
3246

33-
return parser
47+
return args
3448

3549

3650
def log_metrics(engine: Engine, tag: str) -> None:

src/templates/template-text-classification/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _():
175175

176176
# main entrypoint
177177
def main():
178-
config = setup_parser().parse_args()
178+
config = setup_config()
179179
#::: if (it.dist === 'spawn') { :::#
180180
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
181181
kwargs = {

src/templates/template-text-classification/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,33 @@
1818
from ignite.utils import setup_logger
1919

2020

21-
def setup_parser():
22-
with open("config.yaml", "r") as f:
21+
def get_default_parser():
22+
parser = ArgumentParser()
23+
parser.add_argument("config", type=Path, help="Config file path")
24+
parser.add_argument(
25+
"--backend",
26+
default=None,
27+
choices=["nccl", "gloo"],
28+
type=str,
29+
help="DDP backend",
30+
)
31+
return parser
32+
33+
34+
def setup_config(parser=None):
35+
if parser is None:
36+
parser = get_default_parser()
37+
38+
args = parser.parse_args()
39+
config_path = args.config
40+
41+
with open(config_path, "r") as f:
2342
config = yaml.safe_load(f.read())
2443

25-
parser = ArgumentParser()
26-
parser.add_argument("--backend", default=None, type=str)
2744
for k, v in config.items():
28-
if isinstance(v, bool):
29-
parser.add_argument(f"--{k}", action="store_true")
30-
else:
31-
parser.add_argument(f"--{k}", default=v, type=type(v))
45+
setattr(args, k, v)
3246

33-
return parser
47+
return args
3448

3549

3650
def log_metrics(engine: Engine, tag: str) -> None:

src/templates/template-vision-classification/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _():
131131

132132
# main entrypoint
133133
def main():
134-
config = setup_parser().parse_args()
134+
config = setup_config()
135135
#::: if (it.dist === 'spawn') { :::#
136136
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
137137
kwargs = {

src/templates/template-vision-classification/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,33 @@
1818
from ignite.utils import setup_logger
1919

2020

21-
def setup_parser():
22-
with open("config.yaml", "r") as f:
21+
def get_default_parser():
22+
parser = ArgumentParser()
23+
parser.add_argument("config", type=Path, help="Config file path")
24+
parser.add_argument(
25+
"--backend",
26+
default=None,
27+
choices=["nccl", "gloo"],
28+
type=str,
29+
help="DDP backend",
30+
)
31+
return parser
32+
33+
34+
def setup_config(parser=None):
35+
if parser is None:
36+
parser = get_default_parser()
37+
38+
args = parser.parse_args()
39+
config_path = args.config
40+
41+
with open(config_path, "r") as f:
2342
config = yaml.safe_load(f.read())
2443

25-
parser = ArgumentParser()
26-
parser.add_argument("--backend", default=None, type=str)
2744
for k, v in config.items():
28-
if isinstance(v, bool):
29-
parser.add_argument(f"--{k}", action="store_true")
30-
else:
31-
parser.add_argument(f"--{k}", default=v, type=type(v))
45+
setattr(args, k, v)
3246

33-
return parser
47+
return args
3448

3549

3650
def log_metrics(engine: Engine, tag: str) -> None:

src/templates/template-vision-dcgan/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _():
183183

184184
# main entrypoint
185185
def main():
186-
config = setup_parser().parse_args()
186+
config = setup_config()
187187
#::: if (it.dist === 'spawn') { :::#
188188
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
189189
kwargs = {

src/templates/template-vision-dcgan/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,33 @@
1818
from ignite.utils import setup_logger
1919

2020

21-
def setup_parser():
22-
with open("config.yaml", "r") as f:
21+
def get_default_parser():
22+
parser = ArgumentParser()
23+
parser.add_argument("config", type=Path, help="Config file path")
24+
parser.add_argument(
25+
"--backend",
26+
default=None,
27+
choices=["nccl", "gloo"],
28+
type=str,
29+
help="DDP backend",
30+
)
31+
return parser
32+
33+
34+
def setup_config(parser=None):
35+
if parser is None:
36+
parser = get_default_parser()
37+
38+
args = parser.parse_args()
39+
config_path = args.config
40+
41+
with open(config_path, "r") as f:
2342
config = yaml.safe_load(f.read())
2443

25-
parser = ArgumentParser()
26-
parser.add_argument("--backend", default=None, type=str)
2744
for k, v in config.items():
28-
if isinstance(v, bool):
29-
parser.add_argument(f"--{k}", action="store_true")
30-
else:
31-
parser.add_argument(f"--{k}", default=v, type=type(v))
45+
setattr(args, k, v)
3246

33-
return parser
47+
return args
3448

3549

3650
def log_metrics(engine: Engine, tag: str) -> None:

src/templates/template-vision-segmentation/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _():
197197

198198
# main entrypoint
199199
def main():
200-
config = setup_parser().parse_args()
200+
config = setup_config()
201201
#::: if (it.dist === 'spawn') { :::#
202202
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
203203
kwargs = {

src/templates/template-vision-segmentation/utils.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,33 @@
1818
from ignite.utils import setup_logger
1919

2020

21-
def setup_parser():
22-
with open("config.yaml", "r") as f:
21+
def get_default_parser():
22+
parser = ArgumentParser()
23+
parser.add_argument("config", type=Path, help="Config file path")
24+
parser.add_argument(
25+
"--backend",
26+
default=None,
27+
choices=["nccl", "gloo"],
28+
type=str,
29+
help="DDP backend",
30+
)
31+
return parser
32+
33+
34+
def setup_config(parser=None):
35+
if parser is None:
36+
parser = get_default_parser()
37+
38+
args = parser.parse_args()
39+
config_path = args.config
40+
41+
with open(config_path, "r") as f:
2342
config = yaml.safe_load(f.read())
2443

25-
parser = ArgumentParser()
26-
parser.add_argument("--backend", default=None, type=str)
2744
for k, v in config.items():
28-
if isinstance(v, bool):
29-
parser.add_argument(f"--{k}", action="store_true")
30-
else:
31-
parser.add_argument(f"--{k}", default=v, type=type(v))
45+
setattr(args, k, v)
3246

33-
return parser
47+
return args
3448

3549

3650
def log_metrics(engine: Engine, tag: str) -> None:
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
seed: 666
2+
data_path: ~/data
3+
train_batch_size: 2
4+
eval_batch_size: 2
5+
num_workers: 2
6+
max_epochs: 2
7+
train_epoch_length: 4
8+
eval_epoch_length: 4
9+
use_amp: false
10+
debug: false
11+
model: bert-base-uncased
12+
model_dir: /tmp/model
13+
tokenizer_dir: /tmp/tokenizer
14+
num_classes: 1
15+
drop_out: .3
16+
n_fc: 768
17+
weight_decay: 0.01
18+
num_warmup_epochs: 0
19+
max_length: 256
20+
lr: 0.00005
21+
filename_prefix: training
22+
n_saved: 2
23+
save_every_iters: 2
24+
patience: 2
25+
limit_sec: 60
26+
output_dir: ./logs
27+
log_every_iters: 2
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
seed: 666
2+
data_path: ~/data
3+
train_batch_size: 2
4+
eval_batch_size: 2
5+
num_workers: 1
6+
max_epochs: 2
7+
train_epoch_length: 4
8+
eval_epoch_length: 4
9+
use_amp: false
10+
debug: false
11+
model: bert-base-uncased
12+
model_dir: /tmp/model
13+
tokenizer_dir: /tmp/tokenizer
14+
num_classes: 1
15+
drop_out: .3
16+
n_fc: 768
17+
weight_decay: 0.01
18+
num_warmup_epochs: 0
19+
max_length: 256
20+
lr: 0.00005
21+
output_dir: ./logs
22+
log_every_iters: 2
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
seed: 666
2+
data_path: ~/data
3+
train_batch_size: 2
4+
eval_batch_size: 2
5+
num_workers: 2
6+
max_epochs: 2
7+
train_epoch_length: 4
8+
eval_epoch_length: 4
9+
use_amp: false
10+
debug: false
11+
model: bert-base-uncased
12+
model_dir: /tmp/model
13+
tokenizer_dir: /tmp/tokenizer
14+
num_classes: 1
15+
drop_out: .3
16+
n_fc: 768
17+
weight_decay: 0.01
18+
num_warmup_epochs: 0
19+
max_length: 256
20+
lr: 0.00005
21+
output_dir: ./logs
22+
log_every_iters: 2

0 commit comments

Comments
 (0)