Commit c9b0498

Support block_size and bf16_stochastic_round keyword arguments for torchao optimizers (#42972)
feat(trainer): support block_size and bf16_stochastic_round for torchao optimizers
1 parent e8d60a7 commit c9b0498

1 file changed: +6 -0 lines changed

src/transformers/trainer.py

Lines changed: 6 additions & 0 deletions
@@ -1671,6 +1671,12 @@ def optimizer_hook(param):
                 optimizer_cls = AdamW8bit
             else:
                 raise ValueError("Invalid optimizer")
+            optimizer_kwargs.update(
+                {
+                    "block_size": optim_args.get("block_size", 256),
+                    "bf16_stochastic_round": strtobool(optim_args.get("bf16_stochastic_round", "False")),
+                }
+            )
             optimizer_kwargs.update(adam_kwargs)
         elif args.optim in [
             OptimizerNames.SCHEDULE_FREE_RADAM,
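
To illustrate what the six added lines do, here is a minimal standalone sketch. It does not import transformers or torchao. The `parse_optim_args` helper and the local `strtobool` are hypothetical stand-ins written for this example; they assume that `optim_args` is built from a comma-separated `key=value` string and that `strtobool` follows the distutils convention of returning 1 or 0. Neither assumption is confirmed by the diff itself.

```python
def strtobool(val: str) -> int:
    """Local stand-in (assumption): distutils-style truth parsing, returns 1 or 0."""
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if val in ("n", "no", "f", "false", "off", "0"):
        return 0
    raise ValueError(f"invalid truth value {val!r}")


def parse_optim_args(raw: str) -> dict:
    """Hypothetical parser: 'k1=v1,k2=v2' -> {'k1': 'v1', 'k2': 'v2'}.

    Values are kept as strings, which is an assumption of this sketch.
    """
    optim_args = {}
    if raw:
        for mapping in raw.replace(" ", "").split(","):
            key, value = mapping.split("=")
            optim_args[key] = value
    return optim_args


def torchao_extra_kwargs(optim_args: dict) -> dict:
    """Mirror the added lines of the diff: fall back to 256 and "False"."""
    return {
        "block_size": optim_args.get("block_size", 256),
        "bf16_stochastic_round": strtobool(optim_args.get("bf16_stochastic_round", "False")),
    }


# No optim_args given: the defaults from the diff apply.
defaults = torchao_extra_kwargs(parse_optim_args(""))

# Both keys supplied by the user.
custom = torchao_extra_kwargs(parse_optim_args("block_size=128, bf16_stochastic_round=true"))

print(defaults)
print(custom)
```

Under these assumptions, a user-supplied `block_size` stays a string while the default is the integer 256, because only `bf16_stochastic_round` is passed through a converter. Whether the real `optim_args` values are converted elsewhere before reaching the optimizer is not visible in this hunk.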
