diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 06cc00f4..5eddccb1 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -70,7 +70,7 @@ jobs:
           restore-keys: |
             pnpm-and-pip-cache-
-      - run: pip install "black==20.8b1" "isort==5.7.0"
+      - run: pip install -Uq pip wheel && pip install -Uq "black==20.8b1" "isort==5.7.0"
       - run: npm i -g pnpm
       - run: pnpm i --frozen-lockfile --color
       - run: pnpm lint
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 4907dd8a..7bb912ba 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -25,7 +25,7 @@ To contribute to Code-Generator App, you will need Node.js LTS v14.16.x, VSCode,

 - Install pnpm from https://pnpm.io/installation. Use the standalone script if installing with npm runs into permission issues.

-- Install the dependencies with `pnpm install` in the project root directory. This might take a while to install.
+- Install the dependencies with `pnpm install` and `bash .github/run_code_style.sh install` in the project root directory. This might take a while.

 - Run `pnpm run dev` to start the local development server, then edit the code in the `src` directory. Changes will be reflected in the app.

@@ -68,6 +68,8 @@ To add a new template,

 8. For the `utils.py`, copy the starter code from the `src/template-common/utils.py`.

+9. You can check whether the copied code is up to date with the base code by running `python scripts/check_copies.py`.
+
 ## Pull Request Guidelines

 - Checkout a topic branch from a base branch, e.g. `v1` (currently).
diff --git a/scripts/check_copies.py b/scripts/check_copies.py
index 61190894..a3676e53 100644
--- a/scripts/check_copies.py
+++ b/scripts/check_copies.py
@@ -3,44 +3,40 @@
+import sys
+
 from pathlib import Path


-def check_utils():
+def check(fname):
     red = "\033[31m"
     green = "\033[32m"
     reset = "\033[0m"

-    with open("./src/templates/template-common/utils.py", "r") as f:
-        common_utils = f.read()
+    with open(f"./src/templates/template-common/{fname}", "r") as f:
+        common = f.readlines()

     path = Path("./src/templates/")

-    for file in path.rglob("**/utils.py"):
-        utils = file.read_text("utf-8")
-        if utils.find(common_utils) > -1:
-            print(green, "Matched", file, reset)
-        else:
-            print(red, "Unmatched", file, reset)
+    for file in path.rglob(f"**/{fname}"):
+        # skip the base file itself
+        if "common" in str(file):
+            continue
+
+        template = file.read_text("utf-8")
+        # every line of the base file must appear somewhere in the template
+        match = [c in template for c in common]
+
+        if all(match):
+            print(green, "Matched", file, reset)
+        else:
+            print(red, "Unmatched", file, reset)
+            sys.exit(1)


-def check_readme():
-    red = "\033[31m"
-    green = "\033[32m"
-    reset = "\033[0m"
-
-    with open("./src/templates/template-common/README.md", "r") as f:
-        common_utils = f.read()
-
-    path = Path("./src/templates/")
-
-    for file in path.rglob("**/README.md"):
-        utils = file.read_text("utf-8")
-        if utils.find(common_utils) > -1:
-            print(green, "Matched", file, reset)
-        else:
-            print(red, "Unmatched", file, reset)
-
-
 if __name__ == "__main__":
-    check_utils()
+    check("config.yaml")
+    print()
+    check("main.py")
+    print()
+    check("README.md")
     print()
-    check_readme()
+    check("requirements.txt")
     print()
+    check("utils.py")
diff --git a/src/templates/template-common/config.yaml b/src/templates/template-common/config.yaml
new file mode 100644
index 00000000..65eb3fba
--- /dev/null
+++ b/src/templates/template-common/config.yaml
@@ -0,0 +1,52 @@
+backend: null
+seed: 666
+data_path: ./
+train_batch_size: 4
+eval_batch_size: 8
+num_workers: 2
+max_epochs: 2
+train_epoch_length: 4
+eval_epoch_length: 4
+lr: 0.0001 +use_amp: false +verbose: false + +#::: if (it.dist === 'spawn') { :::# +# distributed spawn +nproc_per_node: #:::= it.nproc_per_node :::# +#::: if (it.nnodes) { :::# +# distributed multi node spawn +nnodes: #:::= it.nnodes :::# +node_rank: 0 +master_addr: #:::= it.master_addr :::# +master_port: #:::= it.master_port :::# +#::: } :::# +#::: } :::# + +#::: if (it.filename_prefix) { :::# +filename_prefix: #:::= it.filename_prefix :::# +#::: } :::# + +#::: if (it.n_saved) { :::# +n_saved: #:::= it.n_saved :::# +#::: } :::# + +#::: if (it.save_every_iters) { :::# +save_every_iters: #:::= it.save_every_iters :::# +#::: } :::# + +#::: if (it.patience) { :::# +patience: #:::= it.patience :::# +#::: } :::# + +#::: if (it.limit_sec) { :::# +limit_sec: #:::= it.limit_sec :::# +#::: } :::# + +#::: if (it.output_dir) { :::# +output_dir: #:::= it.output_dir :::# +#::: } :::# + +#::: if (it.log_every_iters) { :::# +log_every_iters: #:::= it.log_every_iters :::# +#::: } :::# diff --git a/src/templates/template-common/main.py b/src/templates/template-common/main.py new file mode 100644 index 00000000..6399ece3 --- /dev/null +++ b/src/templates/template-common/main.py @@ -0,0 +1,57 @@ +ckpt_handler_train, ckpt_handler_eval, timer = setup_handlers( + trainer, evaluator, config, to_save_train, to_save_eval +) + +#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::# +if timer is not None: + logger.info("Time per batch: %.4f seconds", timer.value()) + timer.reset() +#::: } :::# + +#::: if (it.logger) { :::# +if rank == 0: + from ignite.contrib.handlers.wandb_logger import WandBLogger + + if isinstance(exp_logger, WandBLogger): + # why handle differently for wandb? + # See: https://github.com/pytorch/ignite/issues/1894 + exp_logger.finish() + elif exp_logger: + exp_logger.close() +#::: } :::# + +#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::# +if ckpt_handler_train is not None: + logger.info( + "Last training checkpoint name - %s", + ckpt_handler_train.last_checkpoint, + ) + +if ckpt_handler_eval is not None: + logger.info( + "Last evaluation checkpoint name - %s", + ckpt_handler_eval.last_checkpoint, + ) +#::: } :::# + +# main entrypoint +@hydra.main(config_name="config") +def main(config): + #::: if (it.dist === 'spawn') { :::# + #::: if (it.nproc_per_node && it.nnodes && it.master_addr && it.master_port) { :::# + kwargs = { + "nproc_per_node": config.nproc_per_node, + "nnodes": config.nnodes, + "node_rank": config.node_rank, + "master_addr": config.master_addr, + "master_port": config.master_port, + } + #::: } else if (it.nproc_per_node) { :::# + kwargs = {"nproc_per_node": config.nproc_per_node} + #::: } :::# + with idist.Parallel(config.backend, **kwargs) as p: + p.run(run, config=config) + #::: } else { :::# + with idist.Parallel(config.backend) as p: + p.run(run, config=config) + #::: } :::# diff --git a/src/templates/template-common/requirements.txt b/src/templates/template-common/requirements.txt new file mode 100644 index 00000000..56dba2c0 --- /dev/null +++ b/src/templates/template-common/requirements.txt @@ -0,0 +1,14 @@ +torch>=1.8.0 +torchvision>=0.9.0 +pytorch-ignite>=0.4.4 +hydra-core>=1.0.0 + +#::: if (['neptune', 'polyaxon'].includes(it.logger)) { :::# + +#:::= it.logger + '-client' :::# + +#::: } else { :::# + +#:::= it.logger :::# + +#::: } :::# diff --git a/src/templates/template-vision-dcgan/README.md 
b/src/templates/template-vision-dcgan/README.md
new file mode 100644
index 00000000..a0718abd
--- /dev/null
+++ b/src/templates/template-vision-dcgan/README.md
@@ -0,0 +1,137 @@
+# Template by Code-Generator
+
+## Getting Started
+
+Install the dependencies with `pip`:
+
+```sh
+pip install -r requirements.txt --progress-bar off -U
+```
+
+## Training
+
+#::: if (it.dist === 'launch') { :::#
+#::: if (it.nproc_per_node) { :::#
+#::: if (it.nnodes && it.master_addr && it.master_port) { :::#
+
+### Multi Node, Multi GPU Training (`torch.distributed.launch`) (recommended)
+
+- Execute on master node
+
+```sh
+python -m torch.distributed.launch \
+  --nproc_per_node #:::= it.nproc_per_node :::# \
+  --nnodes #:::= it.nnodes :::# \
+  --node_rank 0 \
+  --master_addr #:::= it.master_addr :::# \
+  --master_port #:::= it.master_port :::# \
+  --use_env main.py backend=nccl \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+- Execute on worker nodes
+
+```sh
+python -m torch.distributed.launch \
+  --nproc_per_node #:::= it.nproc_per_node :::# \
+  --nnodes #:::= it.nnodes :::# \
+  --node_rank \
+  --master_addr #:::= it.master_addr :::# \
+  --master_port #:::= it.master_port :::# \
+  --use_env main.py backend=nccl \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+#::: } else { :::#
+
+### Multi GPU Training (`torch.distributed.launch`) (recommended)
+
+```sh
+python -m torch.distributed.launch \
+  --nproc_per_node #:::= it.nproc_per_node :::# \
+  --use_env main.py backend=nccl \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+#::: } :::#
+#::: } :::#
+#::: } :::#
+
+#::: if (it.dist === 'spawn') { :::#
+#::: if (it.nproc_per_node) { :::#
+#::: if (it.nnodes && it.master_addr && it.master_port) { :::#
+
+### Multi Node, Multi GPU Training (`torch.multiprocessing.spawn`)
+
+- Execute on master node
+
+```sh
+python main.py \
+  nproc_per_node=#:::= it.nproc_per_node :::# \
+  nnodes=#:::= it.nnodes :::# \
+  node_rank=0 \
+  master_addr=#:::= it.master_addr :::# \
+  master_port=#:::= it.master_port :::# \
+  backend=nccl \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+- Execute on worker nodes
+
+```sh
+python main.py \
+  nproc_per_node=#:::= it.nproc_per_node :::# \
+  nnodes=#:::= it.nnodes :::# \
+  node_rank= \
+  master_addr=#:::= it.master_addr :::# \
+  master_port=#:::= it.master_port :::# \
+  backend=nccl \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+#::: } else { :::#
+
+### Multi GPU Training (`torch.multiprocessing.spawn`)
+
+```sh
+python main.py \
+  nproc_per_node=#:::= it.nproc_per_node :::# \
+  backend=nccl \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+#::: } :::#
+#::: } :::#
+#::: } :::#
+
+#::: if (!it.nproc_per_node) { :::#
+
+### 1 GPU Training
+
+```sh
+python main.py \
+  hydra.run.dir=. \
+  hydra.output_subdir=null \
+  hydra/job_logging=disabled \
+  hydra/hydra_logging=disabled
+```
+
+#::: } :::#
diff --git a/src/templates/template-vision-dcgan/config.yaml b/src/templates/template-vision-dcgan/config.yaml
new file mode 100644
index 00000000..549e48e8
--- /dev/null
+++ b/src/templates/template-vision-dcgan/config.yaml
@@ -0,0 +1,55 @@
+backend: null
+seed: 666
+data_path: ./
+train_batch_size: 4
+eval_batch_size: 8
+num_workers: 2
+max_epochs: 2
+train_epoch_length: 4
+eval_epoch_length: 4
+lr: 0.0001
+use_amp: false
+verbose: false
+z_dim: 100
+d_filters: 64
+g_filters: 64
+
+#::: if (it.dist === 'spawn') { :::#
+# distributed spawn
+nproc_per_node: #:::= it.nproc_per_node :::#
+#::: if (it.nnodes) { :::#
+# distributed multi node spawn
+nnodes: #:::= it.nnodes :::#
+node_rank: 0
+master_addr: #:::= it.master_addr :::#
+master_port: #:::= it.master_port :::#
+#::: } :::#
+#::: } :::#
+
+#::: if (it.filename_prefix) { :::#
+filename_prefix: #:::= it.filename_prefix :::#
+#::: } :::#
+
+#::: if (it.n_saved) { :::#
+n_saved: #:::= it.n_saved :::#
+#::: } :::#
+
+#::: if (it.save_every_iters) { :::#
+save_every_iters: #:::= it.save_every_iters :::#
+#::: } :::#
+
+#::: if (it.patience) { :::#
+patience: #:::= it.patience :::#
+#::: } :::#
+
+#::: if (it.limit_sec) { :::#
+limit_sec: #:::= it.limit_sec :::#
+#::: } :::#
+
+#::: if (it.output_dir) { :::#
+output_dir: #:::= it.output_dir :::#
+#::: } :::#
+
+#::: if (it.log_every_iters) { :::#
+log_every_iters: #:::= it.log_every_iters :::#
+#::: } :::#
diff --git a/src/templates/template-vision-dcgan/data.py b/src/templates/template-vision-dcgan/data.py
new file mode 100644
index 00000000..e1ddf1e5
--- /dev/null
+++ b/src/templates/template-vision-dcgan/data.py
@@ -0,0 +1,57 @@
+from typing import Any
+
+import ignite.distributed as idist
+import torchvision
+import torchvision.transforms as T
+
+
+def setup_data(config: Any):
+    """Download the datasets and create the dataloaders.
+
+    Parameters
+    ----------
+    config: needs to contain `data_path`, `train_batch_size`, `eval_batch_size`, and `num_workers`
+    """
+    local_rank = idist.get_local_rank()
+    transform = T.Compose(
+        [
+            T.Resize(64),
+            T.ToTensor(),
+            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+        ]
+    )
+
+    if local_rank > 0:
+        # ensure that only rank 0 downloads the dataset
+        idist.barrier()
+
+    dataset_train = torchvision.datasets.CIFAR10(
+        root=config.data_path,
+        train=True,
+        download=True,
+        transform=transform,
+    )
+    dataset_eval = torchvision.datasets.CIFAR10(
+        root=config.data_path,
+        train=False,
+        download=True,
+        transform=transform,
+    )
+    nc = 3  # CIFAR10 images have 3 color channels
+    if local_rank == 0:
+        # rank 0 has finished downloading; release the waiting ranks
+        idist.barrier()
+
+    dataloader_train = idist.auto_dataloader(
+        dataset_train,
+        batch_size=config.train_batch_size,
+        shuffle=True,
+        num_workers=config.num_workers,
+    )
+    dataloader_eval = idist.auto_dataloader(
+        dataset_eval,
+        batch_size=config.eval_batch_size,
+        shuffle=False,
+        num_workers=config.num_workers,
+    )
+    return dataloader_train, dataloader_eval, nc
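For a quick local check of `setup_data`, something like the following should work in a plain single-process run (without distributed initialization, the `idist` calls fall back to serial behavior); the `SimpleNamespace` config here is only illustrative:

```py
# minimal sketch: exercise setup_data() in a single process, assuming
# CIFAR10 can be downloaded into the current directory
from types import SimpleNamespace

from data import setup_data

config = SimpleNamespace(
    data_path="./", train_batch_size=4, eval_batch_size=8, num_workers=0
)
dataloader_train, dataloader_eval, nc = setup_data(config)
images, labels = next(iter(dataloader_train))
print(images.shape)  # torch.Size([4, 3, 64, 64]) after the Resize(64) transform
print(nc)  # 3
```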
diff --git a/src/templates/template-vision-dcgan/main.py b/src/templates/template-vision-dcgan/main.py
new file mode 100644
index 00000000..4f99ebdb
--- /dev/null
+++ b/src/templates/template-vision-dcgan/main.py
@@ -0,0 +1,235 @@
+import os
+from typing import Any
+
+import hydra
+import ignite.distributed as idist
+import torch
+import torchvision.utils as vutils
+from data import setup_data
+from ignite.engine import Events
+from ignite.utils import manual_seed
+from model import Discriminator, Generator
+from omegaconf import OmegaConf
+from torch import nn, optim
+from torch.utils.data.distributed import DistributedSampler
+from trainers import setup_evaluator, setup_trainer
+from utils import *
+
+FAKE_IMG_FNAME = "fake_sample_epoch_{:04d}.png"
+REAL_IMG_FNAME = "real_sample_epoch_{:04d}.png"
+
+
+def run(local_rank: int, config: Any):
+
+    # seed the RNGs deterministically, with a per-rank offset
+    rank = idist.get_rank()
+    manual_seed(config.seed + rank)
+
+    # create the output folder
+    config.output_dir = setup_output_dir(config, rank)
+
+    # download the datasets and create the dataloaders
+    dataloader_train, dataloader_eval, num_channels = setup_data(config)
+
+    # model, optimizer, loss function, device
+    device = idist.device()
+
+    fixed_noise = torch.randn(
+        config.train_batch_size // idist.get_world_size(),
+        config.z_dim,
+        1,
+        1,
+        device=device,
+    )
+
+    # networks
+    model_g = idist.auto_model(
+        Generator(config.z_dim, config.g_filters, num_channels)
+    )
+    model_d = idist.auto_model(Discriminator(num_channels, config.d_filters))
+
+    # loss
+    loss_fn = nn.BCELoss().to(device=device)
+
+    # optimizers
+    optimizer_d = idist.auto_optim(
+        optim.Adam(model_d.parameters(), lr=config.lr, betas=(0.5, 0.999))
+    )
+    optimizer_g = idist.auto_optim(
+        optim.Adam(model_g.parameters(), lr=config.lr, betas=(0.5, 0.999))
+    )
+
+    # trainer and evaluator
+    trainer = setup_trainer(
+        config=config,
+        model_g=model_g,
+        model_d=model_d,
+        optimizer_d=optimizer_d,
+        optimizer_g=optimizer_g,
+        loss_fn=loss_fn,
+        device=device,
+    )
+    evaluator = setup_evaluator(
+        config=config,
+        model_g=model_g,
+        model_d=model_d,
+        loss_fn=loss_fn,
+        device=device,
+    )
+
+    # setup the engines' logger with python logging and
+    # print the training configuration
+    logger = setup_logging(config)
+    logger.info("Configuration: \n%s", OmegaConf.to_yaml(config))
+    trainer.logger = evaluator.logger = logger
+
+    # set the epoch for the distributed sampler
+    if idist.get_world_size() > 1 and isinstance(
+        dataloader_train.sampler, DistributedSampler
+    ):
+        dataloader_train.sampler.set_epoch(trainer.state.epoch - 1)
+
+    # setup ignite handlers
+    #::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
+
+    #::: if (it.save_training) { :::#
+    to_save_train = {
+        "model_d": model_d,
+        "model_g": model_g,
+        "optimizer_d": optimizer_d,
+        "optimizer_g": optimizer_g,
+        "trainer": trainer,
+    }
+    #::: } else { :::#
+    to_save_train = None
+    #::: } :::#
+
+    #::: if (it.save_evaluation) { :::#
+    to_save_eval = {"model_d": model_d, "model_g": model_g}
+    #::: } else { :::#
+    to_save_eval = None
+    #::: } :::#
+
+    ckpt_handler_train, ckpt_handler_eval, timer = setup_handlers(
+        trainer, evaluator, config, to_save_train, to_save_eval
+    )
+    #::: } :::#
+
+    # experiment tracking
+    #::: if (it.logger) { :::#
+    if rank == 0:
+        exp_logger = setup_exp_logging(
+            config,
+            trainer,
+            {"optimizer_d": optimizer_d, "optimizer_g": optimizer_g},
+            evaluator,
+        )
+    #::: } :::#
+
+    # print training stats to stderr
+    # with the `add_event_handler` API
+    trainer.add_event_handler(
+        Events.ITERATION_COMPLETED(every=config.log_every_iters),
+        log_metrics,
+        tag="train",
+    )
+
+    # adding handlers using the `trainer.on` decorator API
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def save_fake_example(engine):
+        fake = model_g(fixed_noise)
+        path = os.path.join(
+            config.output_dir, FAKE_IMG_FNAME.format(engine.state.epoch)
+        )
+        vutils.save_image(fake.detach(), path, normalize=True)
+
+    # another handler added with the `trainer.on` decorator API
+    @trainer.on(Events.EPOCH_COMPLETED)
+    def save_real_example(engine):
+        img, y = engine.state.batch
+        path = os.path.join(
+            config.output_dir, REAL_IMG_FNAME.format(engine.state.epoch)
+        )
+        vutils.save_image(img, path, normalize=True)
+
+    # run evaluation at every training epoch end and
+    # print evaluation stats to stderr with `log_metrics`
+    @trainer.on(Events.EPOCH_COMPLETED(every=1))
+    def _():
+        #::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
+        if timer is not None:
+            logger.info("Time per batch: %.4f seconds", timer.value())
+            timer.reset()
+        #::: } :::#
+
+        evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
+        log_metrics(evaluator, "eval")
+
+    # run evaluation once at startup as a sanity check
+    @trainer.on(Events.STARTED)
+    def _():
+        evaluator.run(dataloader_eval, epoch_length=config.eval_epoch_length)
+
+    # setup is done; run the training
+    trainer.run(
+        dataloader_train,
+        max_epochs=config.max_epochs,
+        epoch_length=config.train_epoch_length,
+    )
+
+    #::: if (it.logger) { :::#
+    if rank == 0:
+        from ignite.contrib.handlers.wandb_logger import WandBLogger
+
+        if isinstance(exp_logger, WandBLogger):
+            # why handle wandb differently?
+            # See: https://github.com/pytorch/ignite/issues/1894
+            exp_logger.finish()
+        elif exp_logger:
+            exp_logger.close()
+    #::: } :::#
+
+    #::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#
+    if ckpt_handler_train is not None:
+        logger.info(
+            "Last training checkpoint name - %s",
+            ckpt_handler_train.last_checkpoint,
+        )
+
+    if ckpt_handler_eval is not None:
+        logger.info(
+            "Last evaluation checkpoint name - %s",
+            ckpt_handler_eval.last_checkpoint,
+        )
+    #::: } :::#
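The entrypoint below funnels both the serial and the distributed paths through `idist.Parallel`. Stripped of the template conditionals, the pattern looks roughly like this sketch (the `run` body here is illustrative): with `backend=None` the function executes once in the current process, while a real backend plus the spawn kwargs starts one worker per process.

```py
import ignite.distributed as idist


def run(local_rank, config):
    # local_rank is injected by idist; rank/world size come from the backend
    print(f"rank {idist.get_rank()} / world size {idist.get_world_size()}")


if __name__ == "__main__":
    # backend=None -> no distributed setup; run() executes once, in-process
    with idist.Parallel(backend=None) as parallel:
        parallel.run(run, config={})
```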
+
+
+# main entrypoint
+@hydra.main(config_name="config")
+def main(config):
+    #::: if (it.dist === 'spawn') { :::#
+    #::: if (it.nproc_per_node && it.nnodes && it.master_addr && it.master_port) { :::#
+    kwargs = {
+        "nproc_per_node": config.nproc_per_node,
+        "nnodes": config.nnodes,
+        "node_rank": config.node_rank,
+        "master_addr": config.master_addr,
+        "master_port": config.master_port,
+    }
+    #::: } else if (it.nproc_per_node) { :::#
+    kwargs = {"nproc_per_node": config.nproc_per_node}
+    #::: } :::#
+    with idist.Parallel(config.backend, **kwargs) as p:
+        p.run(run, config=config)
+    #::: } else { :::#
+    with idist.Parallel(config.backend) as p:
+        p.run(run, config=config)
+    #::: } :::#
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/templates/template-vision-dcgan/model.py b/src/templates/template-vision-dcgan/model.py
new file mode 100644
index 00000000..341a6679
--- /dev/null
+++ b/src/templates/template-vision-dcgan/model.py
@@ -0,0 +1,166 @@
+from torch import nn
+
+
+class Net(nn.Module):
+    """A base class for both the generator and the discriminator.
+    Provides a common weight initialization scheme.
+    """
+
+    def weights_init(self):
+        for m in self.modules():
+            classname = m.__class__.__name__
+
+            if "Conv" in classname:
+                m.weight.data.normal_(0.0, 0.02)
+
+            elif "BatchNorm" in classname:
+                m.weight.data.normal_(1.0, 0.02)
+                m.bias.data.fill_(0)
+
+    def forward(self, x):
+        return x
+
+
+class Generator(Net):
+    """Generator network.
+ Args: + nf (int): Number of filters in the second-to-last deconv layer + """ + + def __init__(self, z_dim, nf, nc): + super(Generator, self).__init__() + + self.net = nn.Sequential( + # input is Z, going into a convolution + nn.ConvTranspose2d( + in_channels=z_dim, + out_channels=nf * 8, + kernel_size=4, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(nf * 8), + nn.ReLU(inplace=True), + # state size. (nf*8) x 4 x 4 + nn.ConvTranspose2d( + in_channels=nf * 8, + out_channels=nf * 4, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(nf * 4), + nn.ReLU(inplace=True), + # state size. (nf*4) x 8 x 8 + nn.ConvTranspose2d( + in_channels=nf * 4, + out_channels=nf * 2, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(nf * 2), + nn.ReLU(inplace=True), + # state size. (nf*2) x 16 x 16 + nn.ConvTranspose2d( + in_channels=nf * 2, + out_channels=nf, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(nf), + nn.ReLU(inplace=True), + # state size. (nf) x 32 x 32 + nn.ConvTranspose2d( + in_channels=nf, + out_channels=nc, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.Tanh() + # state size. (nc) x 64 x 64 + ) + + self.weights_init() + + def forward(self, x): + return self.net(x) + + +class Discriminator(Net): + """Discriminator network. + Args: + nf (int): Number of filters in the first conv layer. + """ + + def __init__(self, nc, nf): + super(Discriminator, self).__init__() + + self.net = nn.Sequential( + # input is (nc) x 64 x 64 + nn.Conv2d( + in_channels=nc, + out_channels=nf, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.LeakyReLU(0.2, inplace=True), + # state size. (nf) x 32 x 32 + nn.Conv2d( + in_channels=nf, + out_channels=nf * 2, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(nf * 2), + nn.LeakyReLU(0.2, inplace=True), + # state size. (nf*2) x 16 x 16 + nn.Conv2d( + in_channels=nf * 2, + out_channels=nf * 4, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(nf * 4), + nn.LeakyReLU(0.2, inplace=True), + # state size. (nf*4) x 8 x 8 + nn.Conv2d( + in_channels=nf * 4, + out_channels=nf * 8, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ), + nn.BatchNorm2d(nf * 8), + nn.LeakyReLU(0.2, inplace=True), + # state size. 
(nf*8) x 4 x 4
+            nn.Conv2d(
+                in_channels=nf * 8,
+                out_channels=1,
+                kernel_size=4,
+                stride=1,
+                padding=0,
+                bias=False,
+            ),
+            nn.Sigmoid(),
+        )
+
+        self.weights_init()
+
+    def forward(self, x):
+        output = self.net(x)
+        # flatten to a 1-D tensor of per-image probabilities
+        return output.view(-1, 1).squeeze(1)
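A quick shape check for the two networks, using the defaults from `config.yaml` (`z_dim: 100`, 64 base filters, 3-channel 64x64 images); the batch size of 4 is arbitrary:

```py
import torch

from model import Discriminator, Generator

g = Generator(z_dim=100, nf=64, nc=3)
d = Discriminator(nc=3, nf=64)

z = torch.randn(4, 100, 1, 1)  # a batch of 4 latent vectors
fake = g(z)
print(fake.shape)     # torch.Size([4, 3, 64, 64])
print(d(fake).shape)  # torch.Size([4]), one probability per image
```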
diff --git a/src/templates/template-vision-dcgan/requirements.txt b/src/templates/template-vision-dcgan/requirements.txt
new file mode 100644
index 00000000..56dba2c0
--- /dev/null
+++ b/src/templates/template-vision-dcgan/requirements.txt
@@ -0,0 +1,14 @@
+torch>=1.8.0
+torchvision>=0.9.0
+pytorch-ignite>=0.4.4
+hydra-core>=1.0.0
+
+#::: if (['neptune', 'polyaxon'].includes(it.logger)) { :::#
+
+#:::= it.logger + '-client' :::#
+
+#::: } else { :::#
+
+#:::= it.logger :::#
+
+#::: } :::#
diff --git a/src/templates/template-vision-dcgan/trainers.py b/src/templates/template-vision-dcgan/trainers.py
new file mode 100644
index 00000000..7801aad4
--- /dev/null
+++ b/src/templates/template-vision-dcgan/trainers.py
@@ -0,0 +1,159 @@
+from typing import Any, Union
+
+import ignite.distributed as idist
+import torch
+from ignite.engine import DeterministicEngine, Engine
+from torch.cuda.amp import autocast
+from torch.nn import Module
+from torch.optim import Optimizer
+
+
+def setup_trainer(
+    config: Any,
+    model_g: Module,
+    model_d: Module,
+    optimizer_d: Optimizer,
+    optimizer_g: Optimizer,
+    loss_fn: Module,
+    device: Union[str, torch.device],
+) -> Union[Engine, DeterministicEngine]:
+
+    ws = idist.get_world_size()
+
+    # label and noise tensors sized for one full per-process batch;
+    # note the same noise tensor is reused on every training iteration
+    real_labels = torch.ones(config.train_batch_size // ws, device=device)
+    fake_labels = torch.zeros(config.train_batch_size // ws, device=device)
+    noise = torch.randn(
+        config.train_batch_size // ws, config.z_dim, 1, 1, device=device
+    )
+
+    def train_function(engine: Union[Engine, DeterministicEngine], batch: Any):
+        model_g.train()
+        model_d.train()
+
+        # unpack the batch; it comes from a labeled dataset, so discard the labels
+        real = batch[0].to(device, non_blocking=True)
+
+        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
+        model_d.zero_grad()
+
+        # train with real
+        with autocast(config.use_amp):
+            outputs = model_d(real)
+            errD_real = loss_fn(outputs, real_labels)
+
+        D_x = outputs.mean().item()
+        errD_real.backward()
+
+        # get a fake image batch from the generator
+        fake = model_g(noise)
+
+        # train with fake
+        with autocast(config.use_amp):
+            outputs = model_d(fake.detach())
+            errD_fake = loss_fn(outputs, fake_labels)
+
+        D_G_z1 = outputs.mean().item()
+
+        errD_fake.backward()
+
+        errD = errD_real + errD_fake
+        optimizer_d.step()
+
+        # (2) Update G network: maximize log(D(G(z)))
+        model_g.zero_grad()
+
+        # update the generator: take a step that makes the discriminator
+        # more likely to classify the fakes as "real"
+        with autocast(config.use_amp):
+            output = model_d(fake)
+            errG = loss_fn(output, real_labels)
+
+        D_G_z2 = output.mean().item()
+
+        errG.backward()
+
+        # gradient update
+        optimizer_g.step()
+
+        metrics = {
+            "epoch": engine.state.epoch,
+            "errD": errD.item(),
+            "errG": errG.item(),
+            "D_x": D_x,
+            "D_G_z1": D_G_z1,
+            "D_G_z2": D_G_z2,
+        }
+        engine.state.metrics = metrics
+
+        return metrics
+
+    #::: if(it.deterministic) { :::#
+    return DeterministicEngine(train_function)
+    #::: } else { :::#
+    return Engine(train_function)
+    #::: } :::#
+
+
+def setup_evaluator(
+    config: Any,
+    model_g: Module,
+    model_d: Module,
+    loss_fn: Module,
+    device: Union[str, torch.device],
+) -> Engine:
+
+    ws = idist.get_world_size()
+
+    real_labels = torch.ones(config.eval_batch_size // ws, device=device)
+    fake_labels = torch.zeros(config.eval_batch_size // ws, device=device)
+    noise = torch.randn(
+        config.eval_batch_size // ws, config.z_dim, 1, 1, device=device
+    )
+
+    @torch.no_grad()
+    def eval_function(engine: Engine, batch: Any):
+        model_g.eval()
+        model_d.eval()
+
+        # unpack the batch; it comes from a labeled dataset, so discard the labels
+        real = batch[0].to(device, non_blocking=True)
+
+        # discriminator outputs on real images
+        with autocast(config.use_amp):
+            outputs = model_d(real)
+            errD_real = loss_fn(outputs, real_labels)
+
+        D_x = outputs.mean().item()
+
+        # get a fake image batch from the generator
+        fake = model_g(noise)
+
+        # discriminator outputs on fake images
+        with autocast(config.use_amp):
+            outputs = model_d(fake.detach())
+            errD_fake = loss_fn(outputs, fake_labels)
+
+        D_G_z1 = outputs.mean().item()
+
+        errD = errD_real + errD_fake
+
+        # generator loss; no parameters are updated during evaluation
+        with autocast(config.use_amp):
+            output = model_d(fake)
+            errG = loss_fn(output, real_labels)
+
+        D_G_z2 = output.mean().item()
+
+        metrics = {
+            "epoch": engine.state.epoch,
+            "errD": errD.item(),
+            "eval_loss": errD.item(),
+            "errG": errG.item(),
+            "D_x": D_x,
+            "D_G_z1": D_G_z1,
+            "D_G_z2": D_G_z2,
+        }
+        engine.state.metrics = metrics
+
+        return metrics
+
+    return Engine(eval_function)
diff --git a/src/templates/template-vision-dcgan/utils.py b/src/templates/template-vision-dcgan/utils.py
new file mode 100644
index 00000000..fedc6de5
--- /dev/null
+++ b/src/templates/template-vision-dcgan/utils.py
@@ -0,0 +1,246 @@
+import logging
+from datetime import datetime
+from logging import Logger
+from pathlib import Path
+from typing import Any, Mapping, Optional, Union
+
+import ignite.distributed as idist
+import torch
+from ignite.contrib.engines import common
+from ignite.engine import Engine
+from ignite.engine.events import Events
+from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
+from ignite.handlers.early_stopping import EarlyStopping
+from ignite.handlers.terminate_on_nan import TerminateOnNan
+from ignite.handlers.time_limit import TimeLimit
+from ignite.handlers.timing import Timer
+from ignite.utils import setup_logger
+
+
+def log_metrics(engine: Engine, tag: str) -> None:
+    """Log `engine.state.metrics` with the given `engine` and `tag`.
+
+    Parameters
+    ----------
+    engine
+        the `Engine` instance whose metrics are logged.
+    tag
+        a string to prepend to the output.
+    """
+    metrics_format = "{0} [{1}/{2}]: {3}".format(
+        tag, engine.state.epoch, engine.state.iteration, engine.state.metrics
+    )
+    engine.logger.info(metrics_format)
+
+
+def resume_from(
+    to_load: Mapping,
+    checkpoint_fp: Union[str, Path],
+    logger: Logger,
+    strict: bool = True,
+    model_dir: Optional[str] = None,
+) -> None:
+    """Load a state dict from a checkpoint file to resume training.
+
+    Parameters
+    ----------
+    to_load
+        a dictionary of objects, e.g. {"model": model, "optimizer": optimizer, ...}
+    checkpoint_fp
+        path to the checkpoint file
+    logger
+        used to log info about resuming from a checkpoint
+    strict
+        whether to strictly enforce that the keys in the state dict match the keys
+        returned by this module's `state_dict()` function. Default: True
+    model_dir
+        directory in which to save the downloaded checkpoint
+    """
+    if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https://"):
+        checkpoint = torch.hub.load_state_dict_from_url(
+            checkpoint_fp,
+            model_dir=model_dir,
+            map_location="cpu",
+            check_hash=True,
+        )
+    else:
+        if isinstance(checkpoint_fp, str):
+            checkpoint_fp = Path(checkpoint_fp)
+
+        if not checkpoint_fp.exists():
+            raise FileNotFoundError(
+                f"Given {str(checkpoint_fp)} does not exist."
+            )
+        checkpoint = torch.load(checkpoint_fp, map_location="cpu")
+
+    Checkpoint.load_objects(
+        to_load=to_load, checkpoint=checkpoint, strict=strict
+    )
+    logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)
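A plausible call site for `resume_from` would be in `main.py` before `trainer.run()`; the checkpoint file name below is hypothetical, and the keys mirror the `to_save_train` dictionary used for checkpointing:

```py
# sketch: restore the models, optimizers, and trainer state from a
# local checkpoint (the path is illustrative, not a template default)
to_load = {
    "model_d": model_d,
    "model_g": model_g,
    "optimizer_d": optimizer_d,
    "optimizer_g": optimizer_g,
    "trainer": trainer,
}
resume_from(to_load, "./checkpoints/checkpoint_8.pt", logger)
```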
+ """ + metrics_format = "{0} [{1}/{2}]: {3}".format( + tag, engine.state.epoch, engine.state.iteration, engine.state.metrics + ) + engine.logger.info(metrics_format) + + +def resume_from( + to_load: Mapping, + checkpoint_fp: Union[str, Path], + logger: Logger, + strict: bool = True, + model_dir: Optional[str] = None, +) -> None: + """Loads state dict from a checkpoint file to resume the training. + + Parameters + ---------- + to_load + a dictionary with objects, e.g. {“model”: model, “optimizer”: optimizer, ...} + checkpoint_fp + path to the checkpoint file + logger + to log info about resuming from a checkpoint + strict + whether to strictly enforce that the keys in `state_dict` match the keys + returned by this module’s `state_dict()` function. Default: True + model_dir + directory in which to save the object + """ + if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https://"): + checkpoint = torch.hub.load_state_dict_from_url( + checkpoint_fp, + model_dir=model_dir, + map_location="cpu", + check_hash=True, + ) + else: + if isinstance(checkpoint_fp, str): + checkpoint_fp = Path(checkpoint_fp) + + if not checkpoint_fp.exists(): + raise FileNotFoundError( + f"Given {str(checkpoint_fp)} does not exist." + ) + checkpoint = torch.load(checkpoint_fp, map_location="cpu") + + Checkpoint.load_objects( + to_load=to_load, checkpoint=checkpoint, strict=strict + ) + logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp) + + +def setup_output_dir(config: Any, rank: int): + """Create output folder.""" + if rank == 0: + now = datetime.now().strftime("%Y%m%d-%H%M%S") + name = f"{now}-backend-{config.backend}-lr-{config.lr}" + path = Path(config.output_dir, name) + path.mkdir(parents=True, exist_ok=True) + config.output_dir = path.as_posix() + + return idist.broadcast(config.output_dir, src=0) + + +def setup_logging(config: Any) -> Logger: + """Setup logger with `ignite.utils.setup_logger()`. + + Parameters + ---------- + config + config object. config has to contain `verbose` and `output_dir` attribute. 
+ + Returns + ------- + logger + an instance of `Logger` + """ + green = "\033[32m" + reset = "\033[0m" + logger = setup_logger( + name=f"{green}[ignite]{reset}", + level=logging.DEBUG if config.verbose else logging.INFO, + format="%(name)s: %(message)s", + filepath=Path(config.output_dir) / "training-info.log", + ) + return logger + + +#::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::# + + +def setup_handlers( + trainer: Engine, + evaluator: Engine, + config: Any, + to_save_train: Optional[dict] = None, + to_save_eval: Optional[dict] = None, +): + """Setup Ignite handlers.""" + + ckpt_handler_train = ckpt_handler_eval = timer = None + #::: if (it.save_training || it.save_evaluation) { :::# + # checkpointing + saver = DiskSaver( + Path(config.output_dir) / "checkpoints", require_empty=False + ) + #::: if (it.save_training) { :::# + ckpt_handler_train = Checkpoint( + to_save_train, + saver, + filename_prefix=config.filename_prefix, + n_saved=config.n_saved, + ) + trainer.add_event_handler( + Events.ITERATION_COMPLETED(every=config.save_every_iters), + ckpt_handler_train, + ) + #::: } :::# + #::: if (it.save_evaluation) { :::# + global_step_transform = None + if to_save_train.get("trainer", None) is not None: + global_step_transform = global_step_from_engine( + to_save_train["trainer"] + ) + ckpt_handler_eval = Checkpoint( + to_save_eval, + saver, + filename_prefix="best", + n_saved=config.n_saved, + global_step_transform=global_step_transform, + ) + evaluator.add_event_handler( + Events.EPOCH_COMPLETED(every=1), ckpt_handler_eval + ) + #::: } :::# + #::: } :::# + + #::: if (it.patience) { :::# + # early stopping + def score_fn(engine: Engine): + return -engine.state.metrics["eval_loss"] + + es = EarlyStopping(config.patience, score_fn, trainer) + evaluator.add_event_handler(Events.EPOCH_COMPLETED, es) + #::: } :::# + + #::: if (it.terminate_on_nan) { :::# + # terminate on nan + trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) + #::: } :::# + + #::: if (it.timer) { :::# + # timer + timer = Timer(average=True) + timer.attach( + trainer, + start=Events.EPOCH_STARTED, + resume=Events.ITERATION_STARTED, + pause=Events.ITERATION_COMPLETED, + step=Events.ITERATION_COMPLETED, + ) + #::: } :::# + + #::: if (it.limit_sec) { :::# + # time limit + trainer.add_event_handler( + Events.ITERATION_COMPLETED, TimeLimit(config.limit_sec) + ) + #::: } :::# + return ckpt_handler_train, ckpt_handler_eval, timer + + +#::: } :::# + +#::: if (it.logger) { :::# + + +def setup_exp_logging(config, trainer, optimizers, evaluators): + """Setup Experiment Tracking logger from Ignite.""" + + #::: if (it.logger === 'clearml') { :::# + logger = common.setup_clearml_logging( + trainer, optimizers, evaluators, config.log_every_iters + ) + #::: } else if (it.logger === 'mlflow') { :::# + logger = common.setup_mlflow_logging( + trainer, optimizers, evaluators, config.log_every_iters + ) + #::: } else if (it.logger === 'neptune') { :::# + logger = common.setup_neptune_logging( + trainer, optimizers, evaluators, config.log_every_iters + ) + #::: } else if (it.logger === 'polyaxon') { :::# + logger = common.setup_plx_logging( + trainer, optimizers, evaluators, config.log_every_iters + ) + #::: } else if (it.logger === 'tensorboard') { :::# + logger = common.setup_tb_logging( + config.output_dir, + trainer, + optimizers, + evaluators, + config.log_every_iters, + ) + #::: } else if (it.logger === 'visdom') { :::# + logger = 
common.setup_visdom_logging( + trainer, optimizers, evaluators, config.log_every_iters + ) + #::: } else if (it.logger === 'wandb') { :::# + logger = common.setup_wandb_logging( + trainer, optimizers, evaluators, config.log_every_iters + ) + #::: } :::# + return logger + + +#::: } :::# diff --git a/src/templates/templates.json b/src/templates/templates.json index aaced4a5..d0b4ec8f 100644 --- a/src/templates/templates.json +++ b/src/templates/templates.json @@ -8,5 +8,15 @@ "trainers.py", "utils.py", "requirements.txt" + ], + "template-vision-dcgan": [ + "README.md", + "config.yaml", + "data.py", + "main.py", + "model.py", + "trainers.py", + "utils.py", + "requirements.txt" ] }