1 parent 7fb8d06 commit 6e535db
references/segmentation/train.py
@@ -10,6 +10,7 @@
 import utils
 from coco_utils import get_coco
 from torch import nn
+from torch.optim.lr_scheduler import PolynomialLR
 from torchvision.transforms import functional as F, InterpolationMode
 
 
@@ -184,8 +185,8 @@ def main(args):
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     iters_per_epoch = len(data_loader)
-    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
+    main_lr_scheduler = PolynomialLR(
+        optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
     )
 
     if args.lr_warmup_epochs > 0:
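The change appears to swap a hand-written polynomial-decay lambda for PyTorch's built-in PolynomialLR, which computes the same (1 - step/total_iters)^power factor. Below is a minimal standalone sketch (illustrative only, not part of the commit; base_lr, total_iters, and the dummy SGD optimizer are made-up stand-ins) checking that the two schedules produce the same learning-rate curve:

# Illustrative check (not from the commit): PolynomialLR(power=0.9) should
# trace the same learning-rate curve as the removed LambdaLR lambda.
# base_lr and total_iters below are arbitrary stand-in values.
import torch
from torch.optim.lr_scheduler import LambdaLR, PolynomialLR

base_lr = 0.01
total_iters = 100  # stands in for iters_per_epoch * (args.epochs - args.lr_warmup_epochs)

def make_optimizer():
    # Dummy single-parameter optimizer; only the lr bookkeeping matters here.
    return torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=base_lr)

opt_old, opt_new = make_optimizer(), make_optimizer()
old_sched = LambdaLR(opt_old, lambda x: (1 - x / total_iters) ** 0.9)
new_sched = PolynomialLR(opt_new, total_iters=total_iters, power=0.9)

for _ in range(total_iters):
    opt_old.step()
    opt_new.step()
    old_sched.step()
    new_sched.step()
    # Both schedulers should agree to within floating-point noise.
    assert abs(old_sched.get_last_lr()[0] - new_sched.get_last_lr()[0]) < 1e-6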