Skip to content

Restructured config #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 20, 2023
35 changes: 4 additions & 31 deletions scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@ run_simple() {
for dir in $(find ./dist-tests/$1-simple -type d)
do
cd $dir
python main.py --data_path ~/data \
--train_batch_size 2 \
--eval_batch_size 2 \
--num_workers 2 \
--max_epochs 2 \
--train_epoch_length 4 \
--eval_epoch_length 4
python main.py ../../src/tests/ci-configs/$1-simple.yaml
cd $CWD
done
}
Expand All @@ -34,13 +28,7 @@ run_all() {
do
cd $dir
pytest -vra --color=yes --tb=short test_*.py
python main.py --data_path ~/data \
--train_batch_size 2 \
--eval_batch_size 2 \
--num_workers 2 \
--max_epochs 2 \
--train_epoch_length 4 \
--eval_epoch_length 4
python main.py ../../src/tests/ci-configs/$1-all.yaml
cd $CWD
done
}
Expand All @@ -49,15 +37,7 @@ run_launch() {
for dir in $(find ./dist-tests/$1-launch -type d)
do
cd $dir
torchrun \
--nproc_per_node 2 \
main.py --backend gloo --data_path ~/data \
--train_batch_size 2 \
--eval_batch_size 2 \
--num_workers 1 \
--max_epochs 2 \
--train_epoch_length 4 \
--eval_epoch_length 4
torchrun --nproc_per_node 2 main.py ../../src/tests/ci-configs/$1-launch.yaml --backend gloo
cd $CWD
done
}
Expand All @@ -66,14 +46,7 @@ run_spawn() {
for dir in $(find ./dist-tests/$1-spawn -type d)
do
cd $dir
python main.py --data_path ~/data \
--nproc_per_node 2 --backend gloo \
--train_batch_size 4 \
--eval_batch_size 4 \
--num_workers 1 \
--max_epochs 2 \
--train_epoch_length 4 \
--eval_epoch_length 4
python main.py ../../src/tests/ci-configs/$1-spawn.yaml --backend gloo
cd $CWD
done
}
Expand Down
2 changes: 1 addition & 1 deletion src/templates/template-common/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# main entrypoint
def main():
config = setup_parser().parse_args()
config = setup_config()
#::: if (it.dist === 'spawn') { :::#
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
kwargs = {
Expand Down
32 changes: 23 additions & 9 deletions src/templates/template-common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@
from ignite.utils import setup_logger


def setup_parser():
with open("config.yaml", "r") as f:
def get_default_parser():
parser = ArgumentParser()
parser.add_argument("config", type=Path, help="Config file path")
parser.add_argument(
"--backend",
default=None,
choices=["nccl", "gloo"],
type=str,
help="DDP backend",
)
return parser


def setup_config(parser=None):
if parser is None:
parser = get_default_parser()

args = parser.parse_args()
config_path = args.config

with open(config_path, "r") as f:
config = yaml.safe_load(f.read())

parser = ArgumentParser()
parser.add_argument("--backend", default=None, type=str)
for k, v in config.items():
if isinstance(v, bool):
parser.add_argument(f"--{k}", action="store_true")
else:
parser.add_argument(f"--{k}", default=v, type=type(v))
setattr(args, k, v)

return parser
return args


def log_metrics(engine: Engine, tag: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/templates/template-text-classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _():

# main entrypoint
def main():
config = setup_parser().parse_args()
config = setup_config()
#::: if (it.dist === 'spawn') { :::#
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
kwargs = {
Expand Down
32 changes: 23 additions & 9 deletions src/templates/template-text-classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@
from ignite.utils import setup_logger


def setup_parser():
with open("config.yaml", "r") as f:
def get_default_parser():
parser = ArgumentParser()
parser.add_argument("config", type=Path, help="Config file path")
parser.add_argument(
"--backend",
default=None,
choices=["nccl", "gloo"],
type=str,
help="DDP backend",
)
return parser


def setup_config(parser=None):
if parser is None:
parser = get_default_parser()

args = parser.parse_args()
config_path = args.config

with open(config_path, "r") as f:
config = yaml.safe_load(f.read())

parser = ArgumentParser()
parser.add_argument("--backend", default=None, type=str)
for k, v in config.items():
if isinstance(v, bool):
parser.add_argument(f"--{k}", action="store_true")
else:
parser.add_argument(f"--{k}", default=v, type=type(v))
setattr(args, k, v)

return parser
return args


def log_metrics(engine: Engine, tag: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/templates/template-vision-classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _():

# main entrypoint
def main():
config = setup_parser().parse_args()
config = setup_config()
#::: if (it.dist === 'spawn') { :::#
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
kwargs = {
Expand Down
32 changes: 23 additions & 9 deletions src/templates/template-vision-classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@
from ignite.utils import setup_logger


def setup_parser():
with open("config.yaml", "r") as f:
def get_default_parser():
parser = ArgumentParser()
parser.add_argument("config", type=Path, help="Config file path")
parser.add_argument(
"--backend",
default=None,
choices=["nccl", "gloo"],
type=str,
help="DDP backend",
)
return parser


def setup_config(parser=None):
if parser is None:
parser = get_default_parser()

args = parser.parse_args()
config_path = args.config

with open(config_path, "r") as f:
config = yaml.safe_load(f.read())

parser = ArgumentParser()
parser.add_argument("--backend", default=None, type=str)
for k, v in config.items():
if isinstance(v, bool):
parser.add_argument(f"--{k}", action="store_true")
else:
parser.add_argument(f"--{k}", default=v, type=type(v))
setattr(args, k, v)

return parser
return args


def log_metrics(engine: Engine, tag: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/templates/template-vision-dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _():

# main entrypoint
def main():
config = setup_parser().parse_args()
config = setup_config()
#::: if (it.dist === 'spawn') { :::#
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
kwargs = {
Expand Down
32 changes: 23 additions & 9 deletions src/templates/template-vision-dcgan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@
from ignite.utils import setup_logger


def setup_parser():
with open("config.yaml", "r") as f:
def get_default_parser():
parser = ArgumentParser()
parser.add_argument("config", type=Path, help="Config file path")
parser.add_argument(
"--backend",
default=None,
choices=["nccl", "gloo"],
type=str,
help="DDP backend",
)
return parser


def setup_config(parser=None):
if parser is None:
parser = get_default_parser()

args = parser.parse_args()
config_path = args.config

with open(config_path, "r") as f:
config = yaml.safe_load(f.read())

parser = ArgumentParser()
parser.add_argument("--backend", default=None, type=str)
for k, v in config.items():
if isinstance(v, bool):
parser.add_argument(f"--{k}", action="store_true")
else:
parser.add_argument(f"--{k}", default=v, type=type(v))
setattr(args, k, v)

return parser
return args


def log_metrics(engine: Engine, tag: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/templates/template-vision-segmentation/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _():

# main entrypoint
def main():
config = setup_parser().parse_args()
config = setup_config()
#::: if (it.dist === 'spawn') { :::#
#::: if (it.nproc_per_node && it.nnodes > 1 && it.master_addr && it.master_port) { :::#
kwargs = {
Expand Down
32 changes: 23 additions & 9 deletions src/templates/template-vision-segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@
from ignite.utils import setup_logger


def setup_parser():
with open("config.yaml", "r") as f:
def get_default_parser():
parser = ArgumentParser()
parser.add_argument("config", type=Path, help="Config file path")
parser.add_argument(
"--backend",
default=None,
choices=["nccl", "gloo"],
type=str,
help="DDP backend",
)
return parser


def setup_config(parser=None):
if parser is None:
parser = get_default_parser()

args = parser.parse_args()
config_path = args.config

with open(config_path, "r") as f:
config = yaml.safe_load(f.read())

parser = ArgumentParser()
parser.add_argument("--backend", default=None, type=str)
for k, v in config.items():
if isinstance(v, bool):
parser.add_argument(f"--{k}", action="store_true")
else:
parser.add_argument(f"--{k}", default=v, type=type(v))
setattr(args, k, v)

return parser
return args


def log_metrics(engine: Engine, tag: str) -> None:
Expand Down
27 changes: 27 additions & 0 deletions src/tests/ci-configs/text-classification-all.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
seed: 666
data_path: ~/data
train_batch_size: 2
eval_batch_size: 2
num_workers: 2
max_epochs: 2
train_epoch_length: 4
eval_epoch_length: 4
use_amp: false
debug: false
model: bert-base-uncased
model_dir: /tmp/model
tokenizer_dir: /tmp/tokenizer
num_classes: 1
drop_out: .3
n_fc: 768
weight_decay: 0.01
num_warmup_epochs: 0
max_length: 256
lr: 0.00005
filename_prefix: training
n_saved: 2
save_every_iters: 2
patience: 2
limit_sec: 60
output_dir: ./logs
log_every_iters: 2
22 changes: 22 additions & 0 deletions src/tests/ci-configs/text-classification-launch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
seed: 666
data_path: ~/data
train_batch_size: 2
eval_batch_size: 2
num_workers: 1
max_epochs: 2
train_epoch_length: 4
eval_epoch_length: 4
use_amp: false
debug: false
model: bert-base-uncased
model_dir: /tmp/model
tokenizer_dir: /tmp/tokenizer
num_classes: 1
drop_out: .3
n_fc: 768
weight_decay: 0.01
num_warmup_epochs: 0
max_length: 256
lr: 0.00005
output_dir: ./logs
log_every_iters: 2
22 changes: 22 additions & 0 deletions src/tests/ci-configs/text-classification-simple.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
seed: 666
data_path: ~/data
train_batch_size: 2
eval_batch_size: 2
num_workers: 2
max_epochs: 2
train_epoch_length: 4
eval_epoch_length: 4
use_amp: false
debug: false
model: bert-base-uncased
model_dir: /tmp/model
tokenizer_dir: /tmp/tokenizer
num_classes: 1
drop_out: .3
n_fc: 768
weight_decay: 0.01
num_warmup_epochs: 0
max_length: 256
lr: 0.00005
output_dir: ./logs
log_every_iters: 2
Loading