 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
+        foreach: Whether to use new "foreach" implementation of optimizer where
+            available (e.g. requires PyTorch 1.12.0 for Adam)
     """
 
     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     weight_decay: float = 0.0
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
+    foreach: Optional[bool] = True
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -115,23 +119,24 @@ def __call__(
         p_groups = [{"params": allprm, "lr": self.lr}]
 
         # Intialize the optimizer
+        optimizer_kwargs: Dict[str, Any] = {
+            "lr": self.lr,
+            "weight_decay": self.weight_decay,
+        }
         if self.breed == "SGD":
-            optimizer = torch.optim.SGD(
-                p_groups,
-                lr=self.lr,
-                momentum=self.momentum,
-                weight_decay=self.weight_decay,
-            )
+            optimizer_class = torch.optim.SGD
+            optimizer_kwargs["momentum"] = self.momentum
         elif self.breed == "Adagrad":
-            optimizer = torch.optim.Adagrad(
-                p_groups, lr=self.lr, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adagrad
         elif self.breed == "Adam":
-            optimizer = torch.optim.Adam(
-                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adam
+            optimizer_kwargs["betas"] = self.betas
         else:
             raise ValueError(f"No such solver type {self.breed}")
+
+        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
+            optimizer_kwargs["foreach"] = self.foreach
+        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
 
         logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
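For context on the new capability check: `inspect.signature` reads the parameters of the chosen optimizer's constructor, so `foreach` is only forwarded when the installed PyTorch build actually accepts it (e.g. Adam gained the flag in 1.12.0). A minimal standalone sketch of the same pattern follows; the parameter tensor and learning rate are arbitrary placeholders, not values from this diff.

```python
import inspect

import torch

# Sketch of the capability check used in the diff: pass "foreach" only
# when the optimizer constructor accepts it, so the same code also runs
# on PyTorch versions that predate the flag.
params = [torch.nn.Parameter(torch.zeros(2))]  # placeholder parameters
optimizer_class = torch.optim.Adam
optimizer_kwargs = {"lr": 1e-3}  # placeholder learning rate

if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
    optimizer_kwargs["foreach"] = True

optimizer = optimizer_class(params, **optimizer_kwargs)
print(type(optimizer).__name__, optimizer.defaults.get("foreach"))
```

Passing the flag unconditionally would raise a `TypeError` on older PyTorch releases, which is why a signature check is used here rather than a version comparison.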