@@ -107,6 +107,9 @@ def execute_backward_optimizers_(  # noqa C901
    uvm_non_rowwise_momentum: bool = False,
    optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
    use_rowwise_bias_correction: bool = False,
+    counter_weight_decay_mode: Optional[
+        CounterWeightDecayMode
+    ] = CounterWeightDecayMode.DECOUPLE,
) -> None:
    # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
@@ -152,6 +155,11 @@ def execute_backward_optimizers_(  # noqa C901
        return
    if mixed_B and (use_cpu or pooling_mode == PoolingMode.NONE):
        return
+    if (
+        pooling_mode == PoolingMode.NONE
+        and counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW
+    ):
+        return

    emb_op = SplitTableBatchedEmbeddingBagsCodegen
    if pooling_mode == PoolingMode.SUM:
@@ -278,12 +286,22 @@ def execute_backward_optimizers_(  # noqa C901
        optimizer_kwargs["weight_decay_mode"] = weight_decay_mode

    if weight_decay_mode == WeightDecayMode.COUNTER:
+        opt_arg_weight_decay_mode: CounterWeightDecayMode = (
+            counter_weight_decay_mode
+            if counter_weight_decay_mode is not None
+            else CounterWeightDecayMode.DECOUPLE
+        )
+        opt_arg_learning_rate_mode: LearningRateMode = (
+            LearningRateMode.TAIL_ID_LR_DECREASE
+            if opt_arg_weight_decay_mode != CounterWeightDecayMode.ADAGRADW
+            else LearningRateMode.EQUAL
+        )
        counter_based_regularization = CounterBasedRegularizationDefinition(
-            counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
+            counter_weight_decay_mode=opt_arg_weight_decay_mode,
            counter_halflife=20000,
-            adjustment_iter=24000,
+            adjustment_iter=-1,
            adjustment_ub=0.1,
-            learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
+            learning_rate_mode=opt_arg_learning_rate_mode,
            grad_sum_decay=GradSumDecay.NO_DECAY,
            tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
        )
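
Note on the hunk above: when the test is parameterized with ADAGRADW, the derived arguments disable the tail-id learning-rate decrease, since ADAGRADW applies its own per-row scaling (see the sqrt(row_counter) factor in a later hunk). A minimal sketch of the configuration this produces, assuming the symbols are importable from fbgemm_gpu's training ops module (the import path is an assumption and may differ by version):

# Hedged sketch; import path assumed, adjust for your fbgemm_gpu version.
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    CounterBasedRegularizationDefinition,
    CounterWeightDecayMode,
    GradSumDecay,
    LearningRateMode,
    TailIdThreshold,
)

reg = CounterBasedRegularizationDefinition(
    counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
    counter_halflife=20000,
    adjustment_iter=-1,  # adjustment disabled (previously hard-coded to 24000)
    adjustment_ub=0.1,
    learning_rate_mode=LearningRateMode.EQUAL,  # TAIL_ID_LR_DECREASE for all other modes
    grad_sum_decay=GradSumDecay.NO_DECAY,
    tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
)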
@@ -545,6 +563,12 @@ def execute_backward_optimizers_(  # noqa C901
            WeightDecayMode.COWCLIP,
        ):
            expected_keys.update(["prev_iter", "row_counter"])
+        if (
+            weight_decay_mode == WeightDecayMode.COUNTER
+            and counter_based_regularization.counter_weight_decay_mode
+            == CounterWeightDecayMode.ADAGRADW
+        ):
+            expected_keys.update(["iter"])
        assert set(optimizer_states_dict.keys()) == expected_keys

    if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
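
To make the expected-state bookkeeping concrete: under COUNTER weight decay with the ADAGRADW counter mode, an extra "iter" state is reported alongside the counter states. A small illustration (the base key set here is hypothetical; it depends on the optimizer under test):

# Hypothetical base keys for a rowwise-Adagrad-style optimizer.
expected_keys = {"momentum1"}
expected_keys.update(["prev_iter", "row_counter"])  # COUNTER / COWCLIP modes
expected_keys.update(["iter"])                      # only COUNTER + ADAGRADW
assert expected_keys == {"momentum1", "prev_iter", "row_counter", "iter"}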
@@ -778,12 +802,13 @@ def _get_grad_from_counter_adagrad(
    l2_wd = 1.0 if counter_weight_decay_mode == CounterWeightDecayMode.L2 else 0.0

    if counter_halflife > 0:
-        freq = torch.tensor([counter_halflife]) / row_counter
+        freq = torch.where(
+            row_counter > 0,
+            torch.tensor([counter_halflife]) / row_counter,
+            torch.tensor([1.0]),
+        )

-    if isinstance(regularization, CounterBasedRegularizationDefinition):
-        dense_cpu_grad += l2_wd * freq * weight_decay * weights
-    else:
-        dense_cpu_grad += l2_wd * weight_decay * weights
+    dense_cpu_grad += l2_wd * weight_decay * weights

    return dense_cpu_grad, row_counter, freq
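
The torch.where rewrite above guards the frequency computation against rows that were never touched (row_counter == 0), which previously produced inf from the division. A self-contained demonstration of the pattern:

import torch

counter_halflife = 20000
row_counter = torch.tensor([0.0, 1.0, 40000.0])

# Both branches are evaluated, but the inf from 20000 / 0 is never selected;
# untouched rows fall back to a frequency of 1.0.
freq = torch.where(
    row_counter > 0,
    torch.tensor([counter_halflife]) / row_counter,
    torch.tensor([1.0]),
)
print(freq)  # tensor([1.0000e+00, 2.0000e+04, 5.0000e-01])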
@@ -863,6 +888,11 @@ def _get_wts_from_counter_adagrad_using_counter(
        exp_reg_correction = 1.0 - freq * weight_decay * learning_rate
    elif counter_weight_decay_mode == CounterWeightDecayMode.L2:
        exp_reg_correction = 1.0 - freq * weight_decay * multiplier
+    elif counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
+        adjusted_multiplier = multiplier * torch.sqrt(row_counter)
+        exp_reg_correction = torch.where(
+            row_counter > 0, 1.0 - weight_decay * learning_rate, 1.0
+        )

    weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
    return weights
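
Numerically, the new ADAGRADW branch scales the Adagrad step by the square root of the row counter and applies decoupled weight decay only to rows that have actually been seen. A minimal sketch with made-up values (multiplier stands in for the Adagrad learning-rate-over-sqrt-of-squared-grad-sum factor):

import torch

learning_rate, weight_decay = 0.01, 0.1
multiplier = torch.tensor([0.01])
row_counter = torch.tensor([0.0, 4.0, 100.0])

adjusted_multiplier = multiplier * torch.sqrt(row_counter)
exp_reg_correction = torch.where(
    row_counter > 0, 1.0 - weight_decay * learning_rate, 1.0
)
# Final update, as in the hunk above:
# weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
print(adjusted_multiplier)  # tensor([0.0000, 0.0200, 0.1000])
print(exp_reg_correction)   # tensor([1.0000, 0.9990, 0.9990])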
@@ -1129,6 +1159,14 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum(  # noqa C901
            WeightDecayMode.COWCLIP,
        ]
    ),
+    counter_weight_decay_mode=st.sampled_from(
+        [
+            CounterWeightDecayMode.NONE,
+            CounterWeightDecayMode.L2,
+            CounterWeightDecayMode.DECOUPLE,
+            CounterWeightDecayMode.ADAGRADW,
+        ]
+    ),
)
@settings(
    verbosity=VERBOSITY,
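
For readers unfamiliar with Hypothesis: st.sampled_from draws one element of the given list per generated example, so the test body now also receives a counter_weight_decay_mode to exercise all four modes. A standalone sketch of the mechanism (the enum and test are stand-ins, not the real fbgemm_gpu symbols):

import enum

import hypothesis.strategies as st
from hypothesis import given

class Mode(enum.Enum):  # stand-in for CounterWeightDecayMode
    NONE = 0
    L2 = 1
    DECOUPLE = 2
    ADAGRADW = 3

@given(mode=st.sampled_from([Mode.NONE, Mode.L2, Mode.DECOUPLE, Mode.ADAGRADW]))
def test_mode_is_member(mode: Mode) -> None:
    assert isinstance(mode, Mode)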
@@ -1152,6 +1190,7 @@ def test_backward_optimizers_adagrad(  # noqa C901
    pooling_mode: PoolingMode,
    use_cpu: bool,
    weight_decay_mode: WeightDecayMode,
+    counter_weight_decay_mode: CounterWeightDecayMode,
) -> None:
    if (
        pooling_mode == PoolingMode.NONE
@@ -1172,6 +1211,7 @@ def test_backward_optimizers_adagrad(  # noqa C901
        pooling_mode,
        use_cpu,
        weight_decay_mode,
+        counter_weight_decay_mode=counter_weight_decay_mode,
    )

@given(