Commit 209c160

bottler authored and facebook-github-bot committed
foreach optimizers
Summary: Allow using the new `foreach` option on optimizers.

Reviewed By: shapovalov

Differential Revision: D39694843

fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
1 parent db3c12a commit 209c160
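
For context (not part of this commit): the `foreach` option tells recent PyTorch optimizers to use the multi-tensor ("foreach") implementation of the parameter update, which batches per-parameter work into fewer kernel launches. A minimal sketch of enabling it directly on `torch.optim.Adam`, assuming PyTorch 1.12.0 or newer:

```python
import torch

model = torch.nn.Linear(2, 2)

# `foreach=True` selects the multi-tensor ("foreach") update path; it is a
# performance option and is not expected to change the optimizer's results.
opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

loss = model(torch.randn(4, 2)).sum()
loss.backward()
opt.step()
```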

File tree: 3 files changed, +33 -13 lines changed

3 files changed

+33
-13
lines changed

projects/implicitron_trainer/impl/optimizer_factory.py

Lines changed: 17 additions & 12 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
+        foreach: Whether to use new "foreach" implementation of optimizer where
+            available (e.g. requires PyTorch 1.12.0 for Adam)
     """
 
     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     weight_decay: float = 0.0
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
+    foreach: Optional[bool] = True
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -115,23 +119,24 @@ def __call__(
         p_groups = [{"params": allprm, "lr": self.lr}]
 
         # Intialize the optimizer
+        optimizer_kwargs: Dict[str, Any] = {
+            "lr": self.lr,
+            "weight_decay": self.weight_decay,
+        }
         if self.breed == "SGD":
-            optimizer = torch.optim.SGD(
-                p_groups,
-                lr=self.lr,
-                momentum=self.momentum,
-                weight_decay=self.weight_decay,
-            )
+            optimizer_class = torch.optim.SGD
+            optimizer_kwargs["momentum"] = self.momentum
         elif self.breed == "Adagrad":
-            optimizer = torch.optim.Adagrad(
-                p_groups, lr=self.lr, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adagrad
         elif self.breed == "Adam":
-            optimizer = torch.optim.Adam(
-                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adam
+            optimizer_kwargs["betas"] = self.betas
         else:
             raise ValueError(f"No such solver type {self.breed}")
+
+        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
+            optimizer_kwargs["foreach"] = self.foreach
+        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
         logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
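
The core of the change: `foreach` is only forwarded when the installed PyTorch exposes that parameter for the chosen optimizer class, which the factory detects via `inspect.signature`. A standalone sketch of the same pattern (the `build_optimizer` helper below is hypothetical, not part of the codebase):

```python
import inspect
from typing import Any, Dict, Iterable, Optional

import torch


def build_optimizer(
    params: Iterable[torch.nn.Parameter],
    breed: str = "Adam",
    lr: float = 1e-3,
    foreach: Optional[bool] = True,
) -> torch.optim.Optimizer:
    optimizer_class = {
        "SGD": torch.optim.SGD,
        "Adagrad": torch.optim.Adagrad,
        "Adam": torch.optim.Adam,
    }[breed]
    optimizer_kwargs: Dict[str, Any] = {"lr": lr}
    # Forward `foreach` only if this PyTorch build accepts it for this
    # optimizer; older versions simply fall back to their default behavior.
    if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
        optimizer_kwargs["foreach"] = foreach
    return optimizer_class(params, **optimizer_kwargs)


model = torch.nn.Linear(2, 2)
opt = build_optimizer(model.parameters(), breed="Adam")
print(type(opt).__name__, opt.defaults.get("foreach"))
```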

projects/implicitron_trainer/tests/experiment.yaml

Lines changed: 1 addition & 0 deletions
@@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   weight_decay: 0.0
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
+  foreach: true
 training_loop_ImplicitronTrainingLoop_args:
   evaluator_class_type: ImplicitronEvaluator
   evaluator_ImplicitronEvaluator_args:

projects/implicitron_trainer/tests/test_experiment.py

Lines changed: 15 additions & 1 deletion
@@ -9,13 +9,17 @@
 import unittest
 from pathlib import Path
 
+import torch
+
 from hydra import compose, initialize_config_dir
 from omegaconf import OmegaConf
+from projects.implicitron_trainer.impl.optimizer_factory import (
+    ImplicitronOptimizerFactory,
+)
 
 from .. import experiment
 from .utils import interactive_testing_requested, intercept_logs
 
-
 internal = os.environ.get("FB_TEST", False)
 
 
@@ -151,6 +155,16 @@ def test_load_configs(self):
         with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
             compose(file.name)
 
+    def test_optimizer_factory(self):
+        model = torch.nn.Linear(2, 2)
+
+        adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
+        self.assertIsInstance(adam, torch.optim.Adam)
+        sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
+        self.assertIsInstance(sgd, torch.optim.SGD)
+        adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
+        self.assertIsInstance(adagrad, torch.optim.Adagrad)
+
 
 class TestNerfRepro(unittest.TestCase):
     @unittest.skip("This test runs full blender training.")
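
As a usage note (not part of the diff): because `foreach` is now a field on the factory, it can be overridden per run; a sketch assuming the same call convention as the test above:

```python
import torch

from projects.implicitron_trainer.impl.optimizer_factory import (
    ImplicitronOptimizerFactory,
)

model = torch.nn.Linear(2, 2)

# Opt out of the foreach implementation explicitly, e.g. to compare against
# the behavior of PyTorch versions that predate the option.
adam, scheduler = ImplicitronOptimizerFactory(breed="Adam", foreach=False)(0, model)
print(type(adam).__name__, adam.defaults.get("foreach"))
```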
