Skip to content

Commit e78ef13

Browse files
Fix LRScheduler issue of PyTorch
- PyTorch 2.0.0 has deprecated _LRScheduler, use LRScheduler instead - Imported _LRScheduler or LRScheduler as PyTorchLRScheduler
1 parent 60c056c commit e78ef13

File tree

2 files changed

+14
-6
lines changed
  • src/templates
    • template-text-classification
    • template-vision-segmentation

2 files changed

+14
-6
lines changed

src/templates/template-text-classification/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
from ignite.utils import manual_seed
1212
from models import TransformerModel
1313
from torch import nn, optim
14-
from torch.optim.lr_scheduler import _LRScheduler
1514
from trainers import setup_evaluator, setup_trainer
1615
from utils import *
1716

17+
try:
18+
from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler
19+
except ImportError:
20+
from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
21+
1822
os.environ["TOKENIZERS_PARALLELISM"] = "false" # remove tokenizer paralleism warning
1923

2024

@@ -78,10 +82,10 @@ def run(local_rank: int, config: Any):
7882
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
7983
trainer.logger = evaluator.logger = logger
8084

81-
if isinstance(lr_scheduler, _LRScheduler):
85+
if isinstance(lr_scheduler, PyTorchLRScheduler):
8286
trainer.add_event_handler(
8387
Events.ITERATION_COMPLETED,
84-
lambda engine: cast(_LRScheduler, lr_scheduler).step(),
88+
lambda engine: cast(PyTorchLRScheduler, lr_scheduler).step(),
8589
)
8690
elif isinstance(lr_scheduler, LRScheduler):
8791
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

src/templates/template-vision-segmentation/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
from ignite.utils import manual_seed
1212
from models import setup_model
1313
from torch import nn, optim
14-
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
14+
from torch.optim.lr_scheduler import LambdaLR
1515
from trainers import setup_evaluator, setup_trainer
1616
from utils import *
1717
from vis import predictions_gt_images_handler
1818

19+
try:
20+
from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler
21+
except ImportError:
22+
from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler
1923

2024
def run(local_rank: int, config: Any):
2125
# make a certain seed
@@ -71,10 +75,10 @@ def run(local_rank: int, config: Any):
7175
(config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
7276
trainer.logger = evaluator.logger = logger
7377

74-
if isinstance(lr_scheduler, _LRScheduler):
78+
if isinstance(lr_scheduler, PyTorchLRScheduler):
7579
trainer.add_event_handler(
7680
Events.ITERATION_COMPLETED,
77-
lambda engine: cast(_LRScheduler, lr_scheduler).step(),
81+
lambda engine: cast(PyTorchLRScheduler, lr_scheduler).step(),
7882
)
7983
elif isinstance(lr_scheduler, LRScheduler):
8084
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

0 commit comments

Comments
 (0)