Commit c0a4b05

minhua-chen authored and facebook-github-bot committed
AdagradW (fbgemm frontend) (pytorch#3850)
Summary:
X-link: facebookresearch/FBGEMM#938

AdagradW (fbgemm frontend)

Differential Revision: D71102031
1 parent ee4c88b commit c0a4b05

File tree

2 files changed: +54, -8 lines

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,7 @@ class CounterWeightDecayMode(enum.IntEnum):
     NONE = 0
     L2 = 1
     DECOUPLE = 2
+    ADAGRADW = 3
 
 
 class StepMode(enum.IntEnum):
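
For context, a minimal usage sketch of the new enum value, mirroring the configuration the updated test uses for ADAGRADW. This assumes CounterBasedRegularizationDefinition and LearningRateMode are importable alongside the enum (their actual module may differ) and that the remaining constructor fields have defaults; neither is confirmed by this diff:

    # Hedged usage sketch, not part of this commit. Import locations are assumptions.
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        CounterBasedRegularizationDefinition,
        CounterWeightDecayMode,
        LearningRateMode,
    )

    reg = CounterBasedRegularizationDefinition(
        counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
        counter_halflife=20000,
        adjustment_iter=-1,  # as in the updated test; presumably disables iter-based adjustment
        learning_rate_mode=LearningRateMode.EQUAL,  # ADAGRADW pairs with EQUAL in the test
    )
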
@@ -2657,6 +2658,11 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
         list_of_state_dict = [
             (
                 {"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
+                | (
+                    {"iter": self.iter}
+                    if self.optimizer_args.weight_decay_mode == 3
+                    else {}
+                )
                 if self._used_rowwise_adagrad_with_counter
                 else (
                     {
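
The dict union above (`|`, Python 3.9+) appends an "iter" entry to the per-table optimizer state only when weight_decay_mode == 3, i.e. the new ADAGRADW value; the test change further below asserts the same key set via expected_keys.update(["iter"]). A tiny self-contained illustration of the resulting keys (placeholder tensors, not the real state):

    import torch

    # Placeholder tensors standing in for the real rowwise-Adagrad-with-counter state.
    states = [torch.zeros(4), torch.zeros(4), torch.zeros(4)]
    weight_decay_mode = 3  # CounterWeightDecayMode.ADAGRADW
    state_dict = {
        "sum": states[0],
        "prev_iter": states[1],
        "row_counter": states[2],
    } | ({"iter": torch.zeros(1)} if weight_decay_mode == 3 else {})
    assert set(state_dict) == {"sum", "prev_iter", "row_counter", "iter"}
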

fbgemm_gpu/test/tbe/training/backward_optimizers_test.py

Lines changed: 48 additions & 8 deletions
@@ -107,6 +107,9 @@ def execute_backward_optimizers_( # noqa C901
         uvm_non_rowwise_momentum: bool = False,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
         use_rowwise_bias_correction: bool = False,
+        counter_weight_decay_mode: Optional[
+            CounterWeightDecayMode
+        ] = CounterWeightDecayMode.DECOUPLE,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
 
@@ -152,6 +155,11 @@ def execute_backward_optimizers_( # noqa C901
             return
         if mixed_B and (use_cpu or pooling_mode == PoolingMode.NONE):
             return
+        if (
+            pooling_mode == PoolingMode.NONE
+            and counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW
+        ):
+            return
 
         emb_op = SplitTableBatchedEmbeddingBagsCodegen
         if pooling_mode == PoolingMode.SUM:

@@ -278,12 +286,22 @@ def execute_backward_optimizers_( # noqa C901
             optimizer_kwargs["weight_decay_mode"] = weight_decay_mode
 
         if weight_decay_mode == WeightDecayMode.COUNTER:
+            opt_arg_weight_decay_mode: CounterWeightDecayMode = (
+                counter_weight_decay_mode
+                if counter_weight_decay_mode is not None
+                else CounterWeightDecayMode.DECOUPLE
+            )
+            opt_arg_learning_rate_mode: LearningRateMode = (
+                LearningRateMode.TAIL_ID_LR_DECREASE
+                if opt_arg_weight_decay_mode != CounterWeightDecayMode.ADAGRADW
+                else LearningRateMode.EQUAL
+            )
             counter_based_regularization = CounterBasedRegularizationDefinition(
-                counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
+                counter_weight_decay_mode=opt_arg_weight_decay_mode,
                 counter_halflife=20000,
-                adjustment_iter=24000,
+                adjustment_iter=-1,
                 adjustment_ub=0.1,
-                learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
+                learning_rate_mode=opt_arg_learning_rate_mode,
                 grad_sum_decay=GradSumDecay.NO_DECAY,
                 tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
             )

@@ -545,6 +563,12 @@ def execute_backward_optimizers_( # noqa C901
                 WeightDecayMode.COWCLIP,
             ):
                 expected_keys.update(["prev_iter", "row_counter"])
+            if (
+                weight_decay_mode == WeightDecayMode.COUNTER
+                and counter_based_regularization.counter_weight_decay_mode
+                == CounterWeightDecayMode.ADAGRADW
+            ):
+                expected_keys.update(["iter"])
             assert set(optimizer_states_dict.keys()) == expected_keys
 
         if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):

@@ -778,12 +802,13 @@ def _get_grad_from_counter_adagrad(
         l2_wd = 1.0 if counter_weight_decay_mode == CounterWeightDecayMode.L2 else 0.0
 
         if counter_halflife > 0:
-            freq = torch.tensor([counter_halflife]) / row_counter
+            freq = torch.where(
+                row_counter > 0,
+                torch.tensor([counter_halflife]) / row_counter,
+                torch.tensor([1.0]),
+            )
 
-        if isinstance(regularization, CounterBasedRegularizationDefinition):
-            dense_cpu_grad += l2_wd * freq * weight_decay * weights
-        else:
-            dense_cpu_grad += l2_wd * weight_decay * weights
+        dense_cpu_grad += l2_wd * weight_decay * weights
         return dense_cpu_grad, row_counter, freq
 
     def _get_wts_from_counter_adagrad_using_counter(
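
The torch.where guard above fixes a divide-by-zero in the reference implementation: rows that were never sampled keep row_counter == 0, so the old expression produced inf frequencies for them; the new code falls back to a neutral frequency of 1.0. (The same hunk also collapses the isinstance branch, dropping the freq factor from the L2 term.) A standalone illustration with invented values:

    import torch

    counter_halflife = 20000
    row_counter = torch.tensor([0.0, 1.0, 40000.0])  # first row never seen

    # Old: torch.tensor([counter_halflife]) / row_counter -> inf for row 0.
    # New: unseen rows get a neutral frequency of 1.0 instead.
    freq = torch.where(
        row_counter > 0,
        torch.tensor([counter_halflife]) / row_counter,
        torch.tensor([1.0]),
    )
    print(freq)  # -> 1.0, 20000.0, 0.5
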
@@ -863,6 +888,11 @@ def _get_wts_from_counter_adagrad_using_counter(
             exp_reg_correction = 1.0 - freq * weight_decay * learning_rate
         elif counter_weight_decay_mode == CounterWeightDecayMode.L2:
             exp_reg_correction = 1.0 - freq * weight_decay * multiplier
+        elif counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
+            adjusted_multiplier = multiplier * torch.sqrt(row_counter)
+            exp_reg_correction = torch.where(
+                row_counter > 0, 1.0 - weight_decay * learning_rate, 1.0
+            )
 
         weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
         return weights
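
Putting the branch above together with the final update line, the reference ADAGRADW update per row is roughly w_new = (1 - weight_decay * lr) * w - multiplier * sqrt(row_counter) * grad for rows with row_counter > 0, i.e. AdamW-style decoupled weight decay on top of rowwise Adagrad with a counter-scaled step; rows with a zero counter get neither decay nor step. A self-contained numeric sketch, assuming the usual rowwise-Adagrad multiplier lr / (sqrt(G) + eps) (the exact multiplier definition lives elsewhere in this test file):

    import torch

    lr, weight_decay, eps = 0.1, 0.01, 1.0e-4
    grad = torch.tensor([0.2, 0.3])
    G = torch.tensor([0.04, 0.09])           # accumulated squared-gradient average (assumed)
    row_counter = torch.tensor([25.0, 0.0])  # second row never observed
    weights = torch.ones(2)

    multiplier = lr / (torch.sqrt(G) + eps)  # assumed rowwise-Adagrad step size
    adjusted_multiplier = multiplier * torch.sqrt(row_counter)  # zero step for unseen rows
    exp_reg_correction = torch.where(
        row_counter > 0,
        torch.tensor(1.0 - weight_decay * lr),  # decoupled decay only on seen rows
        torch.tensor(1.0),
    )
    weights = exp_reg_correction * weights - adjusted_multiplier * grad
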
@@ -1129,6 +1159,14 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901
                 WeightDecayMode.COWCLIP,
             ]
         ),
+        counter_weight_decay_mode=st.sampled_from(
+            [
+                CounterWeightDecayMode.NONE,
+                CounterWeightDecayMode.L2,
+                CounterWeightDecayMode.DECOUPLE,
+                CounterWeightDecayMode.ADAGRADW,
+            ]
+        ),
     )
     @settings(
         verbosity=VERBOSITY,

@@ -1152,6 +1190,7 @@ def test_backward_optimizers_adagrad( # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         weight_decay_mode: WeightDecayMode,
+        counter_weight_decay_mode: CounterWeightDecayMode,
     ) -> None:
         if (
             pooling_mode == PoolingMode.NONE

@@ -1172,6 +1211,7 @@ def test_backward_optimizers_adagrad( # noqa C901
             pooling_mode,
             use_cpu,
             weight_decay_mode,
+            counter_weight_decay_mode=counter_weight_decay_mode,
         )
 
     @given(
