@@ -107,6 +107,7 @@ def execute_backward_optimizers_( # noqa C901
         uvm_non_rowwise_momentum: bool = False,
         optimizer_state_dtypes: Optional[Dict[str, SparseType]] = None,
         use_rowwise_bias_correction: bool = False,
+        counter_weight_decay_mode: Optional[CounterWeightDecayMode] = None,
     ) -> None:
         # NOTE: limit (T * B * L * D) to avoid timeout for CPU version!

@@ -278,15 +279,31 @@ def execute_backward_optimizers_( # noqa C901
         optimizer_kwargs["weight_decay_mode"] = weight_decay_mode

         if weight_decay_mode == WeightDecayMode.COUNTER:
-            counter_based_regularization = CounterBasedRegularizationDefinition(
-                counter_weight_decay_mode=CounterWeightDecayMode.DECOUPLE,
-                counter_halflife=20000,
-                adjustment_iter=24000,
-                adjustment_ub=0.1,
-                learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
-                grad_sum_decay=GradSumDecay.NO_DECAY,
-                tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
+            opt_arg_weight_decay_mode: CounterWeightDecayMode = (
+                counter_weight_decay_mode
+                if counter_weight_decay_mode is not None
+                else CounterWeightDecayMode.DECOUPLE
             )
+            if opt_arg_weight_decay_mode != CounterWeightDecayMode.ADAGRADW:
+                counter_based_regularization = CounterBasedRegularizationDefinition(
+                    counter_weight_decay_mode=opt_arg_weight_decay_mode,
+                    counter_halflife=20000,
+                    adjustment_iter=24000,
+                    adjustment_ub=0.1,
+                    learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
+                    grad_sum_decay=GradSumDecay.NO_DECAY,
+                    tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
+                )
+            else:
+                counter_based_regularization = CounterBasedRegularizationDefinition(
+                    counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
+                    counter_halflife=-1,
+                    adjustment_iter=-1,
+                    adjustment_ub=0.1,
+                    learning_rate_mode=LearningRateMode.EQUAL,
+                    grad_sum_decay=GradSumDecay.NO_DECAY,
+                    tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
+                )
             # fmt: off
             optimizer_kwargs["counter_based_regularization"] = (
                 counter_based_regularization
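
For readers of the hunk above: the new optional counter_weight_decay_mode argument only changes which CounterBasedRegularizationDefinition the test constructs. The helper below is a standalone restatement of those two configurations, not part of the patch, and the fbgemm_gpu import path is an assumption that may vary across versions.

from typing import Optional

from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    CounterBasedRegularizationDefinition,
    CounterWeightDecayMode,
    GradSumDecay,
    LearningRateMode,
    TailIdThreshold,
)


def make_counter_regularization(
    mode: Optional[CounterWeightDecayMode] = None,
) -> CounterBasedRegularizationDefinition:
    # Fall back to DECOUPLE when no mode is passed, as the diff does.
    mode = mode if mode is not None else CounterWeightDecayMode.DECOUPLE
    if mode != CounterWeightDecayMode.ADAGRADW:
        # NONE / L2 / DECOUPLE: counter half-life and the tail-id
        # learning-rate adjustment stay enabled.
        return CounterBasedRegularizationDefinition(
            counter_weight_decay_mode=mode,
            counter_halflife=20000,
            adjustment_iter=24000,
            adjustment_ub=0.1,
            learning_rate_mode=LearningRateMode.TAIL_ID_LR_DECREASE,
            grad_sum_decay=GradSumDecay.NO_DECAY,
            tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
        )
    # ADAGRADW: counter half-life and the adjustment iteration are
    # disabled (-1) and the learning-rate mode is EQUAL.
    return CounterBasedRegularizationDefinition(
        counter_weight_decay_mode=CounterWeightDecayMode.ADAGRADW,
        counter_halflife=-1,
        adjustment_iter=-1,
        adjustment_ub=0.1,
        learning_rate_mode=LearningRateMode.EQUAL,
        grad_sum_decay=GradSumDecay.NO_DECAY,
        tail_id_threshold=TailIdThreshold(val=1000, is_ratio=False),
    )
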
@@ -545,6 +562,12 @@ def execute_backward_optimizers_( # noqa C901
                 WeightDecayMode.COWCLIP,
             ):
                 expected_keys.update(["prev_iter", "row_counter"])
+                if (
+                    weight_decay_mode == WeightDecayMode.COUNTER
+                    and counter_based_regularization.counter_weight_decay_mode
+                    == CounterWeightDecayMode.ADAGRADW
+                ):
+                    expected_keys.update(["iter"])
             assert set(optimizer_states_dict.keys()) == expected_keys

             if optimizer in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.ADAM):
@@ -778,12 +801,13 @@ def _get_grad_from_counter_adagrad(
     l2_wd = 1.0 if counter_weight_decay_mode == CounterWeightDecayMode.L2 else 0.0

     if counter_halflife > 0:
-        freq = torch.tensor([counter_halflife]) / row_counter
+        freq = torch.where(
+            row_counter > 0,
+            torch.tensor([counter_halflife]) / row_counter,
+            torch.tensor([1.0]),
+        )

-    if isinstance(regularization, CounterBasedRegularizationDefinition):
-        dense_cpu_grad += l2_wd * freq * weight_decay * weights
-    else:
-        dense_cpu_grad += l2_wd * weight_decay * weights
+    dense_cpu_grad += l2_wd * weight_decay * weights

     return dense_cpu_grad, row_counter, freq

 def _get_wts_from_counter_adagrad_using_counter(
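
Why the guard: rows that never received gradients still have row_counter equal to zero, and the removed expression divided the half-life by that counter directly. Below is a minimal standalone illustration of the replacement; it is not part of the patch and assumes only PyTorch.

import torch

# counter_halflife / row_counter yields inf for rows whose counter is
# still zero; the torch.where guard falls back to 1.0 for those rows.
counter_halflife = 20000
row_counter = torch.tensor([0.0, 5.0, 20000.0])

freq = torch.where(
    row_counter > 0,
    torch.tensor([float(counter_halflife)]) / row_counter,
    torch.tensor([1.0]),
)
# freq is [1.0, 4000.0, 1.0]: unseen rows get 1.0 instead of inf.
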
@@ -863,6 +887,21 @@ def _get_wts_from_counter_adagrad_using_counter(
         exp_reg_correction = 1.0 - freq * weight_decay * learning_rate
     elif counter_weight_decay_mode == CounterWeightDecayMode.L2:
         exp_reg_correction = 1.0 - freq * weight_decay * multiplier
+    else:
+        if adjustment_iter <= 0 or (
+            adjustment_iter > 0 and iter_ > adjustment_iter
+        ):
+            if counter_weight_decay_mode == CounterWeightDecayMode.ADAGRADW:
+                adjusted_multiplier = torch.where(
+                    row_counter > 0,
+                    multiplier * torch.sqrt(row_counter),
+                    torch.Tensor([0.0]),
+                )
+                exp_reg_correction = torch.where(
+                    row_counter > 0,
+                    1.0 - weight_decay * learning_rate,
+                    torch.Tensor([1.0]),
+                )

     weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
     return weights
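
The ADAGRADW branch just added can be hard to read inline with the surrounding diff, so here is a short standalone sketch of the same arithmetic. It is not part of the patch and makes two assumptions: multiplier is the per-row Adagrad step size computed earlier in the reference implementation, and the adjustment_iter gate is skipped because the test configures it as -1 for ADAGRADW.

import torch

learning_rate = 0.01
weight_decay = 0.1
row_counter = torch.tensor([[0.0], [4.0]])     # one counter per row
multiplier = torch.tensor([[0.02], [0.02]])    # assumed per-row Adagrad step size
weights = torch.ones(2, 3)
dense_cpu_grad = torch.full((2, 3), 0.5)

# Rows never seen (counter == 0) are left untouched: their multiplier is
# zeroed and their regularization correction stays at 1.0.
adjusted_multiplier = torch.where(
    row_counter > 0,
    multiplier * torch.sqrt(row_counter),
    torch.tensor([0.0]),
)
exp_reg_correction = torch.where(
    row_counter > 0,
    torch.tensor([1.0 - weight_decay * learning_rate]),
    torch.tensor([1.0]),
)
weights = exp_reg_correction * weights - adjusted_multiplier * dense_cpu_grad
# Row 0 stays at 1.0; row 1 becomes 0.999 - 0.02 * sqrt(4) * 0.5 = 0.979.
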
@@ -1129,6 +1168,14 @@ def test_backward_optimizers_partial_rowwise_adam_bf16_momentum( # noqa C901
                 WeightDecayMode.COWCLIP,
             ]
         ),
+        counter_weight_decay_mode=st.sampled_from(
+            [
+                CounterWeightDecayMode.NONE,
+                CounterWeightDecayMode.L2,
+                CounterWeightDecayMode.DECOUPLE,
+                CounterWeightDecayMode.ADAGRADW,
+            ]
+        ),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -1152,6 +1199,7 @@ def test_backward_optimizers_adagrad( # noqa C901
         pooling_mode: PoolingMode,
         use_cpu: bool,
         weight_decay_mode: WeightDecayMode,
+        counter_weight_decay_mode: CounterWeightDecayMode,
     ) -> None:
         if (
             pooling_mode == PoolingMode.NONE
@@ -1172,6 +1220,7 @@ def test_backward_optimizers_adagrad( # noqa C901
             pooling_mode,
             use_cpu,
             weight_decay_mode,
+            counter_weight_decay_mode=counter_weight_decay_mode,
         )

     @given(