11 | 11 | from ignite.utils import manual_seed
12 | 12 | from models import setup_model
13 | 13 | from torch import nn, optim
14 |    | -from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
   | 14 | +from torch.optim.lr_scheduler import LambdaLR
15 | 15 | from trainers import setup_evaluator, setup_trainer
16 | 16 | from utils import *
17 | 17 | from vis import predictions_gt_images_handler
18 | 18 |
   | 19 | +try:
   | 20 | +    from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
   | 21 | +except ImportError:
   | 22 | +    from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler
19 | 23 |
20 | 24 | def run(local_rank: int, config: Any):
21 | 25 |     # make a certain seed
@@ -71,10 +75,10 @@ def run(local_rank: int, config: Any):
71 | 75 |     (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
72 | 76 |     trainer.logger = evaluator.logger = logger
73 | 77 |
74 |    | -    if isinstance(lr_scheduler, _LRScheduler):
   | 78 | +    if isinstance(lr_scheduler, PyTorchLRScheduler):
75 | 79 |         trainer.add_event_handler(
76 | 80 |             Events.ITERATION_COMPLETED,
77 |    | -            lambda engine: cast(_LRScheduler, lr_scheduler).step(),
   | 81 | +            lambda engine: cast(PyTorchLRScheduler, lr_scheduler).step(),
78 | 82 |         )
79 | 83 |     elif isinstance(lr_scheduler, LRScheduler):
80 | 84 |         trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
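For reference, the change above is a compatibility shim for PyTorch 2.0, which renamed the private _LRScheduler base class to the public LRScheduler; importing whichever name exists under a single alias keeps the isinstance dispatch working on both old and new releases. The elif branch presumably handles Ignite's own LRScheduler param-scheduler handler (imported elsewhere in the template), which is callable and can be attached as an event handler directly. Below is a minimal, self-contained sketch of the same pattern; the model, optimizer, scheduler, and data are illustrative placeholders, not the project's own code.

# Minimal sketch (illustrative names, not the project's code) showing why the diff
# dispatches on the resolved base class: a plain torch scheduler needs an explicit
# .step() once per iteration, driven here through an Ignite event handler.
try:
    # PyTorch >= 2.0 exposes the public LRScheduler base class ...
    from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
except ImportError:
    # ... older releases only have the private _LRScheduler.
    from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler

import torch
from ignite.engine import Engine, Events
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR

model = nn.Linear(4, 2)                      # toy model, stands in for the project's setup_model
optimizer = optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda it: 0.95 ** it)

def train_step(engine, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()                # dummy loss, just enough to drive optimizer.step()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)

# Same dispatch as in the diff: every torch scheduler (LambdaLR included) is an
# instance of the aliased base class, so it gets stepped after each iteration.
if isinstance(lr_scheduler, PyTorchLRScheduler):
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())

trainer.run([torch.randn(8, 4) for _ in range(4)], max_epochs=2)

Attaching the step through ITERATION_COMPLETED keeps the scheduler in lockstep with the optimizer without touching the training step itself.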