@@ -107,6 +107,7 @@ def execute_backward_optimizers_(  # noqa C901
         uvm_non_rowwise_momentum: bool = False,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
         use_rowwise_bias_correction: bool = False,
+        counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!
@@ -152,6 +153,11 @@ def execute_backward_optimizers_(  # noqa C901
             return
         if mixed_B and (use_cpu or pooling_mode == PoolingMode.NONE):
             return
+        if (
+            pooling_mode == PoolingMode.NONE
+            and counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW
+        ):
+            return
 
         emb_op = SplitTableBatchedEmbeddingBagsCodegen
         if pooling_mode == PoolingMode.SUM:
@@ -278,12 +284,22 @@ def execute_backward_optimizers_(  # noqa C901
             optimizer_kwargs["weight_decay_mode"] = weight_decay_mode
 
             if weight_decay_mode == WeightDecayMode.COUNTER:
+                opt_arg_weight_decay_mode: CounterWeightDecayMode = (
+                    counter_weight_decay_mode
+                    if counter_weight_decay_mode is not None
+                    else CounterWeightDecayMode.DECOUPLE
+                )
+                opt_arg_learning_rate_mode: LearningRateMode = (
+                    LearningRateMode.TAIL_ID_LR_DECREASE
+                    if opt_arg_weight_decay_mode != CounterWeightDecayMode.ADAGRADW
+                    else LearningRateMode.EQUAL
+                )
                 counter_based_regularization = CounterBasedRegularizationDefinition(
-                    counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
+                    counter_weight_decay_mode=opt_arg_weight_decay_mode,
                     counter_halflife=20000,
-                    adjustment_iter=24000,
+                    adjustment_iter=-1,
                     adjustment_ub=0.1,
-                    learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
+                    learning_rate_mode=opt_arg_learning_rate_mode,
                     grad_sum_decay=GradSumDecay.NO_DECAY,
                     tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
                 )
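
For readers skimming the hunk above: the new opt_arg_* selection pairs ADAGRADW with LearningRateMode.EQUAL, keeps TAIL_ID_LR_DECREASE for every other counter mode, and defaults an unset mode to DECOUPLE. A minimal, self-contained sketch of that selection logic (the Enum classes below are local stand-ins with placeholder values, not the FBGEMM definitions):

    from enum import Enum, auto
    from typing import Optional, Tuple


    class CounterWeightDecayMode(Enum):
        # Local stand-in for the FBGEMM enum; member values are placeholders.
        NONE = auto()
        L2 = auto()
        DECOUPLE = auto()
        ADAGRADW = auto()


    class LearningRateMode(Enum):
        # Local stand-in for the FBGEMM enum; member values are placeholders.
        EQUAL = auto()
        TAIL_ID_LR_DECREASE = auto()


    def select_counter_args(
        mode: Optional[CounterWeightDecayMode],
    ) -> Tuple[CounterWeightDecayMode, LearningRateMode]:
        # An unset mode falls back to DECOUPLE, the previously hard-coded value.
        wd_mode = mode if mode is not None else CounterWeightDecayMode.DECOUPLE
        # ADAGRADW is paired with EQUAL; all other modes keep TAIL_ID_LR_DECREASE.
        lr_mode = (
            LearningRateMode.EQUAL
            if wd_mode == CounterWeightDecayMode.ADAGRADW
            else LearningRateMode.TAIL_ID_LR_DECREASE
        )
        return wd_mode, lr_mode


    assert select_counter_args(None) == (
        CounterWeightDecayMode.DECOUPLE,
        LearningRateMode.TAIL_ID_LR_DECREASE,
    )
    assert select_counter_args(CounterWeightDecayMode.ADAGRADW) == (
        CounterWeightDecayMode.ADAGRADW,
        LearningRateMode.EQUAL,
    )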
@@ -545,6 +561,12 @@ def execute_backward_optimizers_(  # noqa C901
                 WeightDecayMode.COWCLIP,
             ):
                 expected_keys.update(["prev_iter", "row_counter"])
+                if (
+                    weight_decay_mode == WeightDecayMode.COUNTER
+                    and counter_based_regularization.counter_weight_decay_mode
+                    == CounterWeightDecayMode.ADAGRADW
+                ):
+                    expected_keys.update(["iter"])
             assert set(optimizer_states_dict.keys()) == expected_keys
 
         if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
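
The hunk above encodes the optimizer-state expectation that only the ADAGRADW counter mode adds an "iter" state on top of the "prev_iter" and "row_counter" states shared by COUNTER and COWCLIP runs. A tiny restatement of that expectation (the boolean flag is a hypothetical stand-in for the enum comparison):

    expected_keys = {"prev_iter", "row_counter"}
    is_adagradw = True  # stand-in for counter_weight_decay_mode == ADAGRADW
    if is_adagradw:
        expected_keys.update(["iter"])
    assert expected_keys == {"prev_iter", "row_counter", "iter"}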
@@ -778,12 +800,13 @@ def _get_grad_from_counter_adagrad(
    l2_wd = 1.0 if counter_weight_decay_mode == CounterWeightDecayMode.L2 else 0.0
 
    if counter_halflife > 0:
-        freq = torch.tensor([counter_halflife]) / row_counter
+        freq = torch.where(
+            row_counter > 0,
+            torch.tensor([counter_halflife]) / row_counter,
+            torch.tensor([1.0]),
+        )
 
-    if isinstance(regularization, CounterBasedRegularizationDefinition):
-        dense_cpu_grad += l2_wd * freq * weight_decay * weights
-    else:
-        dense_cpu_grad += l2_wd * weight_decay * weights
+    dense_cpu_grad += l2_wd * weight_decay * weights
 
    return dense_cpu_grad, row_counter, freq
 
 def _get_wts_from_counter_adagrad_using_counter(
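
The torch.where guard introduced above keeps the reference computation from dividing by zero for rows whose counter is still zero; such rows fall back to a frequency of 1.0. A quick standalone check, with illustrative counter values:

    import torch

    counter_halflife = 20000
    # Hypothetical per-row counters; the middle row has never been updated.
    row_counter = torch.tensor([4000.0, 0.0, 20000.0])
    freq = torch.where(
        row_counter > 0,
        torch.tensor([counter_halflife]) / row_counter,
        torch.tensor([1.0]),
    )
    # Positive counts yield halflife / counter; zero counts fall back to 1.0.
    print(freq)  # tensor([5., 1., 1.])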
@@ -863,6 +886,11 @@ def _get_wts_from_counter_adagrad_using_counter(
         exp_reg_correction = 1.0 - freq * weight_decay * learning_rate
     elif counter_weight_decay_mode == CounterWeightDecayMode.L2:
         exp_reg_correction = 1.0 - freq * weight_decay * multiplier
+    elif counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
+        adjusted_multiplier = multiplier * torch.sqrt(row_counter)
+        exp_reg_correction = torch.where(
+            row_counter > 0, 1.0 - weight_decay * learning_rate, 1.0
+        )
 
    weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
    return weights
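
The new ADAGRADW branch scales the Adagrad multiplier by the square root of the per-row counter and applies decoupled weight decay only to rows that have actually been seen. A minimal numeric illustration of that update (all tensor values hypothetical):

    import torch

    learning_rate, weight_decay = 0.01, 0.1
    multiplier = torch.tensor([[0.5], [0.5]])
    row_counter = torch.tensor([[4.0], [0.0]])  # second row never seen

    # Gradient step is amplified by sqrt of the row count (zero for unseen rows).
    adjusted_multiplier = multiplier * torch.sqrt(row_counter)
    # Decoupled weight decay is skipped entirely for unseen rows.
    exp_reg_correction = torch.where(
        row_counter > 0, 1.0 - weight_decay * learning_rate, 1.0
    )

    weights = torch.ones(2, 1)
    dense_cpu_grad = torch.full((2, 1), 0.2)
    weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
    print(weights)  # tensor([[0.7990], [1.0000]]): unseen row is left untouched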
@@ -1129,6 +1157,14 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum(  # noqa C901
                 WeightDecayMode.COWCLIP,
             ]
         ),
+        counter_weight_decay_mode=st.sampled_from(
+            [
+                CounterWeightDecayMode.NONE,
+                CounterWeightDecayMode.L2,
+                CounterWeightDecayMode.DECOUPLE,
+                CounterWeightDecayMode.ADAGRADW,
+            ]
+        ),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -1152,6 +1188,7 @@ def test_backward_optimizers_adagrad(  # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         weight_decay_mode: WeightDecayMode,
+        counter_weight_decay_mode: CounterWeightDecayMode,
     ) -> None:
         if (
             pooling_mode == PoolingMode.NONE
@@ -1172,6 +1209,7 @@ def test_backward_optimizers_adagrad(  # noqa C901
             pooling_mode,
             use_cpu,
             weight_decay_mode,
+            counter_weight_decay_mode=counter_weight_decay_mode,
         )
 
     @given(