
Commit 01759ca

minhua-chen authored and facebook-github-bot committed
AdagradW (fbgemm frontend) (#3850)
Summary:
X-link: facebookresearch/FBGEMM#938

AdagradW (fbgemm frontend)

Reviewed By: sryap, csmiler

Differential Revision: D71102031
1 parent 4587ad0 commit 01759ca

File tree

2 files changed: +74 lines, -14 lines

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 12 additions & 1 deletion
@@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum):
     NONE = 0
     L2 = 1
     DECOUPLE = 2
+    ADAGRADW = 3
 
 
 class StepMode(enum.IntEnum):
@@ -2656,7 +2657,17 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
         ):
             list_of_state_dict = [
                 (
-                    {"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
+                    {
+                        "sum": states[0],
+                        "prev_iter": states[1],
+                        "row_counter": states[2],
+                        **(
+                            {"iter": self.iter}
+                            if self.optimizer_args.weight_decay_mode
+                            == CounterWeightDecayMode.ADAGRADW.value
+                            else {}
+                        ),
+                    }
                     if self._used_rowwise_adagrad_with_counter
                     else (
                         {
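With the ADAGRADW counter mode selected, each per-table optimizer state dict gains an "iter" entry alongside the counter states. A minimal sketch of what a caller could check, assuming `emb` is an already-constructed TBE module using EXACT_ROWWISE_ADAGRAD with WeightDecayMode.COUNTER and CounterWeightDecayMode.ADAGRADW (`emb` is hypothetical, not part of this diff; the key set mirrors the expected_keys assertion in the test changes below):

# `emb` is a hypothetical, already-constructed TBE module (EXACT_ROWWISE_ADAGRAD,
# WeightDecayMode.COUNTER, CounterWeightDecayMode.ADAGRADW).
states = emb.get_optimizer_state()  # one dict per embedding table
# ADAGRADW adds the global step ("iter") to the usual counter-based states.
assert set(states[0].keys()) == {"sum", "prev_iter", "row_counter", "iter"}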

fbgemm_gpu/test/tbe/training/backward_optimizers_test.py

Lines changed: 62 additions & 13 deletions
@@ -107,6 +107,7 @@ def execute_backward_optimizers_( # noqa C901
         uvm_non_rowwise_momentum: bool = False,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
         use_rowwise_bias_correction: bool = False,
+        counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
@@ -278,15 +279,31 @@ def execute_backward_optimizers_( # noqa C901
             optimizer_kwargs["weight_decay_mode"] = weight_decay_mode
 
             if weight_decay_mode == WeightDecayMode.COUNTER:
-                counter_based_regularization = CounterBasedRegularizationDefinition(
-                    counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
-                    counter_halflife=20000,
-                    adjustment_iter=24000,
-                    adjustment_ub=0.1,
-                    learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
-                    grad_sum_decay=GradSumDecay.NO_DECAY,
-                    tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
+                opt_arg_weight_decay_mode: CounterWeightDecayMode = (
+                    counter_weight_decay_mode
+                    if counter_weight_decay_mode is not None
+                    else CounterWeightDecayMode.DECOUPLE
                 )
+                if opt_arg_weight_decay_mode != CounterWeightDecayMode.ADAGRADW:
+                    counter_based_regularization = CounterBasedRegularizationDefinition(
+                        counter_weight_decay_mode=opt_arg_weight_decay_mode,
+                        counter_halflife=20000,
+                        adjustment_iter=24000,
+                        adjustment_ub=0.1,
+                        learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
+                        grad_sum_decay=GradSumDecay.NO_DECAY,
+                        tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
+                    )
+                else:
+                    counter_based_regularization = CounterBasedRegularizationDefinition(
+                        counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
+                        counter_halflife=-1,
+                        adjustment_iter=-1,
+                        adjustment_ub=0.1,
+                        learning_rate_mode=LearningRateMode.EQUAL,
+                        grad_sum_decay=GradSumDecay.NO_DECAY,
+                        tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
+                    )
                 # fmt: off
                 optimizer_kwargs["counter_based_regularization"] = (
                     counter_based_regularization
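The test path above suggests how AdagradW is selected from the frontend. Below is a minimal sketch of the optimizer kwargs, mirroring the ADAGRADW branch of the test setup (the learning_rate and weight_decay values are illustrative, not prescriptive):

# Sketch, mirroring the test setup above: AdagradW goes through the COUNTER
# weight-decay path; counter_halflife=-1 and adjustment_iter=-1 disable the
# frequency/adjustment machinery, and the learning-rate mode is EQUAL.
optimizer_kwargs = {
    "learning_rate": 0.01,   # illustrative value
    "weight_decay": 0.01,    # illustrative value
    "weight_decay_mode": WeightDecayMode.COUNTER,
    "counter_based_regularization": CounterBasedRegularizationDefinition(
        counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
        counter_halflife=-1,
        adjustment_iter=-1,
        adjustment_ub=0.1,
        learning_rate_mode=LearningRateMode.EQUAL,
        grad_sum_decay=GradSumDecay.NO_DECAY,
        tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
    ),
}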
@@ -545,6 +562,12 @@ def execute_backward_optimizers_( # noqa C901
                 WeightDecayMode.COWCLIP,
             ):
                 expected_keys.update(["prev_iter", "row_counter"])
+            if (
+                weight_decay_mode == WeightDecayMode.COUNTER
+                and counter_based_regularization.counter_weight_decay_mode
+                == CounterWeightDecayMode.ADAGRADW
+            ):
+                expected_keys.update(["iter"])
             assert set(optimizer_states_dict.keys()) == expected_keys
 
         if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
@@ -778,12 +801,13 @@ def _get_grad_from_counter_adagrad(
         l2_wd = 1.0 if counter_weight_decay_mode == CounterWeightDecayMode.L2 else 0.0
 
         if counter_halflife > 0:
-            freq = torch.tensor([counter_halflife]) / row_counter
+            freq = torch.where(
+                row_counter > 0,
+                torch.tensor([counter_halflife]) / row_counter,
+                torch.tensor([1.0]),
+            )
 
-        if isinstance(regularization, CounterBasedRegularizationDefinition):
-            dense_cpu_grad += l2_wd * freq * weight_decay * weights
-        else:
-            dense_cpu_grad += l2_wd * weight_decay * weights
+        dense_cpu_grad += l2_wd * weight_decay * weights
         return dense_cpu_grad, row_counter, freq
 
     def _get_wts_from_counter_adagrad_using_counter(
@@ -863,6 +887,21 @@ def _get_wts_from_counter_adagrad_using_counter(
                 exp_reg_correction = 1.0 - freq * weight_decay * learning_rate
             elif counter_weight_decay_mode == CounterWeightDecayMode.L2:
                 exp_reg_correction = 1.0 - freq * weight_decay * multiplier
+        else:
+            if adjustment_iter <= 0 or (
+                adjustment_iter > 0 and iter_ > adjustment_iter
+            ):
+                if counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
+                    adjusted_multiplier = torch.where(
+                        row_counter > 0,
+                        multiplier * torch.sqrt(row_counter),
+                        torch.Tensor([0.0]),
+                    )
+                    exp_reg_correction = torch.where(
+                        row_counter > 0,
+                        1.0 - weight_decay * learning_rate,
+                        torch.Tensor([1.0]),
+                    )
 
         weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
         return weights
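The reference update above amounts to decoupled (AdamW-style) weight decay applied on top of a count-scaled rowwise-Adagrad step. A self-contained sketch of the per-row update, assuming `multiplier` is the usual rowwise-Adagrad step size (roughly learning_rate / (sqrt(sum) + eps)), which this hunk takes as given:

import torch

def adagradw_row_update(
    weights: torch.Tensor,         # one embedding row
    dense_cpu_grad: torch.Tensor,  # dense gradient for that row
    multiplier: torch.Tensor,      # rowwise-Adagrad step size (assumed lr / (sqrt(sum) + eps))
    row_counter: torch.Tensor,     # number of times the row has been touched
    learning_rate: float,
    weight_decay: float,
) -> torch.Tensor:
    # Scale the Adagrad step by sqrt(row_counter); rows never seen get no update.
    adjusted_multiplier = torch.where(
        row_counter > 0, multiplier * torch.sqrt(row_counter), torch.tensor([0.0])
    )
    # Decoupled weight decay, applied only to rows that have been seen.
    exp_reg_correction = torch.where(
        row_counter > 0,
        torch.tensor([1.0 - weight_decay * learning_rate]),
        torch.tensor([1.0]),
    )
    return exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad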
@@ -1129,6 +1168,14 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901
                 WeightDecayMode.COWCLIP,
             ]
         ),
+        counter_weight_decay_mode=st.sampled_from(
+            [
+                CounterWeightDecayMode.NONE,
+                CounterWeightDecayMode.L2,
+                CounterWeightDecayMode.DECOUPLE,
+                CounterWeightDecayMode.ADAGRADW,
+            ]
+        ),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -1152,6 +1199,7 @@ def test_backward_optimizers_adagrad( # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         weight_decay_mode: WeightDecayMode,
+        counter_weight_decay_mode: CounterWeightDecayMode,
     ) -> None:
         if (
             pooling_mode == PoolingMode.NONE
@@ -1172,6 +1220,7 @@ def test_backward_optimizers_adagrad( # noqa C901
             pooling_mode,
             use_cpu,
             weight_decay_mode,
+            counter_weight_decay_mode=counter_weight_decay_mode,
         )
 
     @given(
