diff --git a/src/templates/template-vision-segmentation/main.py b/src/templates/template-vision-segmentation/main.py index 9ad19858..7133ad28 100644 --- a/src/templates/template-vision-segmentation/main.py +++ b/src/templates/template-vision-segmentation/main.py @@ -1,15 +1,16 @@ from functools import partial from pprint import pformat -from typing import Any +from typing import Any, cast import ignite.distributed as idist import yaml from data import denormalize, download_datasets, setup_data +from ignite.contrib.handlers import LRScheduler from ignite.engine import Events from ignite.metrics import ConfusionMatrix, IoU, mIoU from ignite.utils import manual_seed from torch import nn, optim -from torch.optim.lr_scheduler import LambdaLR +from torch.optim.lr_scheduler import LambdaLR, _LRScheduler from torch.utils.data.distributed import DistributedSampler from torchvision.models.segmentation import deeplabv3_resnet101 from trainers import setup_evaluator, setup_trainer @@ -80,6 +81,16 @@ def set_epoch(): ): dataloader_train.sampler.set_epoch(trainer.state.epoch - 1) + if isinstance(lr_scheduler, _LRScheduler): + trainer.add_event_handler( + Events.ITERATION_COMPLETED, + lambda engine: cast(_LRScheduler, lr_scheduler).step(), + ) + elif isinstance(lr_scheduler, LRScheduler): + trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler) + else: + trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) + # setup ignite handlers #::: if (it.save_training || it.save_evaluation || it.patience || it.terminate_on_nan || it.timer || it.limit_sec) { :::#