Commit d82f05f

minhua-chen authored and facebook-github-bot committed
AdagradW (fbgemm frontend) (#3850)
Summary:
X-link: facebookresearch/FBGEMM#938

AdagradW (fbgemm frontend)

Differential Revision: D71102031
1 parent 00cd37b commit d82f05f

File tree

2 files changed: +53 −8 lines changed


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 7 additions & 0 deletions
@@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum):
     NONE = 0
     L2 = 1
     DECOUPLE = 2
+    ADAGRADW = 3


 class StepMode(enum.IntEnum):
@@ -2657,6 +2658,12 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
         list_of_state_dict = [
             (
                 {"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
+                | (
+                    {"iter": self.iter}
+                    if self.optimizer_args.weight_decay_mode
+                    == CounterWeightDecayMode.ADAGRADW.value
+                    else {}
+                )
                 if self._used_rowwise_adagrad_with_counter
                 else (
                     {
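The dict union above is what surfaces the new state: "iter" is merged into the per-table optimizer state only when the counter weight decay mode is ADAGRADW. A minimal, self-contained sketch of that pattern (plain Python, requires 3.9+ for the dict | operator; the placeholder state values are illustrative, not the real tensors):

from enum import IntEnum

class CounterWeightDecayMode(IntEnum):
    NONE = 0
    L2 = 1
    DECOUPLE = 2
    ADAGRADW = 3

def build_counter_state(states, weight_decay_mode, iter_value):
    # Base rowwise-adagrad-with-counter state dict; "iter" is merged in only
    # for ADAGRADW, mirroring the dict union in get_optimizer_state above.
    return {"sum": states[0], "prev_iter": states[1], "row_counter": states[2]} | (
        {"iter": iter_value}
        if weight_decay_mode == CounterWeightDecayMode.ADAGRADW.value
        else {}
    )

# Illustrative values only:
print(build_counter_state(["s", "p", "r"], CounterWeightDecayMode.ADAGRADW.value, 42))
# {'sum': 's', 'prev_iter': 'p', 'row_counter': 'r', 'iter': 42}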

fbgemm_gpu/test/tbe/training/backward_optimizers_test.py

Lines changed: 46 additions & 8 deletions
@@ -107,6 +107,7 @@ def execute_backward_optimizers_(  # noqa C901
         uvm_non_rowwise_momentum: bool = False,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
         use_rowwise_bias_correction: bool = False,
+        counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!

@@ -152,6 +153,11 @@ def execute_backward_optimizers_(  # noqa C901
             return
         if mixed_B and (use_cpu or pooling_mode == PoolingMode.NONE):
             return
+        if (
+            pooling_mode == PoolingMode.NONE
+            and counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW
+        ):
+            return

         emb_op = SplitTableBatchedEmbeddingBagsCodegen
         if pooling_mode == PoolingMode.SUM:
@@ -278,12 +284,22 @@ def execute_backward_optimizers_(  # noqa C901
             optimizer_kwargs["weight_decay_mode"] = weight_decay_mode

         if weight_decay_mode == WeightDecayMode.COUNTER:
+            opt_arg_weight_decay_mode: CounterWeightDecayMode = (
+                counter_weight_decay_mode
+                if counter_weight_decay_mode is not None
+                else CounterWeightDecayMode.DECOUPLE
+            )
+            opt_arg_learning_rate_mode: LearningRateMode = (
+                LearningRateMode.TAIL_ID_LR_DECREASE
+                if opt_arg_weight_decay_mode != CounterWeightDecayMode.ADAGRADW
+                else LearningRateMode.EQUAL
+            )
             counter_based_regularization = CounterBasedRegularizationDefinition(
-                counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
+                counter_weight_decay_mode=opt_arg_weight_decay_mode,
                 counter_halflife=20000,
-                adjustment_iter=24000,
+                adjustment_iter=-1,
                 adjustment_ub=0.1,
-                learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
+                learning_rate_mode=opt_arg_learning_rate_mode,
                 grad_sum_decay=GradSumDecay.NO_DECAY,
                 tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
             )
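From the frontend, ADAGRADW is not a new optimizer type; it is selected through the existing counter-based regularization config, which is exactly what the test does in the hunk above. A minimal sketch of building that config with the same values (the import path is an assumption based on where the enum is defined in the first file; the diff itself only shows the constructor keywords):

# Hedged sketch: the constructor keywords mirror the test above; the import
# path is assumed, not shown in this diff.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    CounterBasedRegularizationDefinition,
    CounterWeightDecayMode,
    GradSumDecay,
    LearningRateMode,
    TailIdThreshold,
)

adagradw_reg = CounterBasedRegularizationDefinition(
    counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
    counter_halflife=20000,
    adjustment_iter=-1,  # the test switches this from 24000 to -1
    adjustment_ub=0.1,
    # The test pairs ADAGRADW with LearningRateMode.EQUAL (see opt_arg_learning_rate_mode).
    learning_rate_mode=LearningRateMode.EQUAL,
    grad_sum_decay=GradSumDecay.NO_DECAY,
    tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
)
# This config is then passed to the TBE module together with
# weight_decay_mode=WeightDecayMode.COUNTER, as execute_backward_optimizers_ does.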
@@ -545,6 +561,12 @@ def execute_backward_optimizers_(  # noqa C901
                 WeightDecayMode.COWCLIP,
             ):
                 expected_keys.update(["prev_iter", "row_counter"])
+            if (
+                weight_decay_mode == WeightDecayMode.COUNTER
+                and counter_based_regularization.counter_weight_decay_mode
+                == CounterWeightDecayMode.ADAGRADW
+            ):
+                expected_keys.update(["iter"])
             assert set(optimizer_states_dict.keys()) == expected_keys

         if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
@@ -778,12 +800,13 @@ def _get_grad_from_counter_adagrad(
         l2_wd = 1.0 if counter_weight_decay_mode == CounterWeightDecayMode.L2 else 0.0

         if counter_halflife > 0:
-            freq = torch.tensor([counter_halflife]) / row_counter
+            freq = torch.where(
+                row_counter > 0,
+                torch.tensor([counter_halflife]) / row_counter,
+                torch.tensor([1.0]),
+            )

-        if isinstance(regularization, CounterBasedRegularizationDefinition):
-            dense_cpu_grad += l2_wd * freq * weight_decay * weights
-        else:
-            dense_cpu_grad += l2_wd * weight_decay * weights
+        dense_cpu_grad += l2_wd * weight_decay * weights
         return dense_cpu_grad, row_counter, freq

     def _get_wts_from_counter_adagrad_using_counter(
@@ -863,6 +886,11 @@ def _get_wts_from_counter_adagrad_using_counter(
             exp_reg_correction = 1.0 - freq * weight_decay * learning_rate
         elif counter_weight_decay_mode == CounterWeightDecayMode.L2:
             exp_reg_correction = 1.0 - freq * weight_decay * multiplier
+        elif counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
+            adjusted_multiplier = multiplier * torch.sqrt(row_counter)
+            exp_reg_correction = torch.where(
+                row_counter > 0, 1.0 - weight_decay * learning_rate, 1.0
+            )

         weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
         return weights
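The ADAGRADW branch above is the reference update the test checks the backend against: the rowwise Adagrad step is rescaled by sqrt(row_counter), and weight decay is applied decoupled from the gradient, only to rows that have actually been seen. A self-contained sketch of that row update (the multiplier line is the usual rowwise-Adagrad form and is assumed here; the diff only shows how it is adjusted):

import torch

def adagradw_reference_row_update(
    weights: torch.Tensor,         # (D,) one embedding row
    dense_cpu_grad: torch.Tensor,  # (D,) dense gradient for that row
    momentum_sum: torch.Tensor,    # scalar tensor: running sum of squared row grads
    row_counter: torch.Tensor,     # scalar tensor: how often this row was touched
    learning_rate: float,
    weight_decay: float,
    eps: float,
) -> torch.Tensor:
    # Usual rowwise Adagrad multiplier (assumed; not part of this diff).
    multiplier = learning_rate / (torch.sqrt(momentum_sum) + eps)
    # ADAGRADW: scale the step by sqrt(row_counter) ...
    adjusted_multiplier = multiplier * torch.sqrt(row_counter)
    # ... and apply decoupled weight decay only to rows that were seen.
    exp_reg_correction = torch.where(
        row_counter > 0,
        torch.tensor(1.0 - weight_decay * learning_rate),
        torch.tensor(1.0),
    )
    return exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad

# Illustrative call with made-up values:
# adagradw_reference_row_update(torch.zeros(4), torch.ones(4),
#     torch.tensor(4.0), torch.tensor(10.0), 0.1, 0.01, 1e-8)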
@@ -1129,6 +1157,14 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum(  # noqa C901
                 WeightDecayMode.COWCLIP,
             ]
         ),
+        counter_weight_decay_mode=st.sampled_from(
+            [
+                CounterWeightDecayMode.NONE,
+                CounterWeightDecayMode.L2,
+                CounterWeightDecayMode.DECOUPLE,
+                CounterWeightDecayMode.ADAGRADW,
+            ]
+        ),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -1152,6 +1188,7 @@ def test_backward_optimizers_adagrad(  # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         weight_decay_mode: WeightDecayMode,
+        counter_weight_decay_mode: CounterWeightDecayMode,
     ) -> None:
         if (
             pooling_mode == PoolingMode.NONE
@@ -1172,6 +1209,7 @@ def test_backward_optimizers_adagrad(  # noqa C901
             pooling_mode,
             use_cpu,
             weight_decay_mode,
+            counter_weight_decay_mode=counter_weight_decay_mode,
         )

     @given(

0 commit comments
