
Commit 04052c4

SunMarc authored and BernardZach committed
FEAT / Trainer: Add adamw 4bit optimizer (huggingface#31865)
* add 4bit optimizer
* style
* fix msg
* style
* add qgalore
* Revert "add qgalore"

  This reverts commit 25278e8.
* style
* version check
1 parent b597aa7 commit 04052c4
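
In user code the feature is opt-in through the existing `optim` flag. A minimal sketch of enabling it, assuming `torchao>=0.4.0` and `torch>=2.3` are installed; the Trainer call is commented out because the model and dataset are placeholders, not part of this commit:

    from transformers import Trainer, TrainingArguments

    training_args = TrainingArguments(
        output_dir="out",
        optim="adamw_torch_4bit",  # new value added by this commit (OptimizerNames.ADAMW_TORCH_4BIT)
        learning_rate=2e-5,
    )
    # trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    # trainer.train()  # optimizer states are kept in 4 bits via torchao's AdamW4bit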

File tree: 3 files changed, +29 −0 lines changed


src/transformers/trainer.py

Lines changed: 17 additions & 0 deletions
@@ -168,6 +168,7 @@
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
+    is_torchao_available,
     logging,
     strtobool,
 )
@@ -1451,7 +1452,23 @@ def optimizer_hook(param):
                     "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
                 }
             )
+        elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+            if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
+                "0.4.0"
+            ):
+                raise ImportError(
+                    "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
+                    "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
+                )
+            if version.parse(importlib.metadata.version("torch")) < version.parse("2.3"):
+                raise ImportError(
+                    "You need to have `torch>=2.3` in order to use torch 4-bit optimizers. "
+                    "Install it with `pip install --upgrade torch`"
+                )
+            from torchao.prototype.low_bit_optim import AdamW4bit
 
+            optimizer_cls = AdamW4bit
+            optimizer_kwargs.update(adam_kwargs)
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
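
For reference, the new branch boils down to the following standalone construction. A sketch under the same version assumptions as the guards above, with a toy module standing in for the real model:

    import importlib.metadata

    import torch
    from packaging import version

    # Same version guards the Trainer applies (it additionally checks is_torchao_available()).
    if version.parse(importlib.metadata.version("torchao")) < version.parse("0.4.0"):
        raise ImportError("torchao>=0.4.0 is required for the 4-bit AdamW optimizer.")
    if version.parse(importlib.metadata.version("torch")) < version.parse("2.3"):
        raise ImportError("torch>=2.3 is required for the 4-bit AdamW optimizer.")

    from torchao.prototype.low_bit_optim import AdamW4bit

    model = torch.nn.Linear(128, 128)  # toy stand-in for the real model
    # The Trainer passes its usual Adam kwargs (lr, betas, eps); shown here explicitly.
    optimizer = AdamW4bit(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8)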

src/transformers/training_args.py

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
+    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
     SGD = "sgd"
     ADAGRAD = "adagrad"
     ADAMW_BNB = "adamw_bnb_8bit"
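
The enum member is what maps the user-facing string to the new optimizer branch in trainer.py. A quick illustration of that mapping, assuming a transformers build containing this commit:

    from transformers import TrainingArguments
    from transformers.training_args import OptimizerNames

    assert OptimizerNames.ADAMW_TORCH_4BIT.value == "adamw_torch_4bit"

    # TrainingArguments normalizes the plain string to the enum member.
    args = TrainingArguments(output_dir="out", optim="adamw_torch_4bit")
    assert args.optim == OptimizerNames.ADAMW_TORCH_4BIT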

tests/trainer/test_trainer.py

Lines changed: 11 additions & 0 deletions
@@ -99,6 +99,7 @@
     is_apex_available,
     is_bitsandbytes_available,
     is_safetensors_available,
+    is_torchao_available,
     is_torchdistx_available,
 )
 from transformers.utils.hp_naming import TrialShortNamer
@@ -4210,6 +4211,16 @@ def hp_name(trial):
             dict(default_adam_kwargs, **default_anyprecision_kwargs),
         )
     )
+    if is_torchao_available():
+        import torchao
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
+                torchao.prototype.low_bit_optim.AdamW4bit,
+                default_adam_kwargs,
+            )
+        )
 
 
 @require_torch
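
The appended tuple is exercised by the same parametrized check as the other optimizer entries. Roughly what that check does for the new entry, written as a standalone sketch (it assumes torchao is installed; the exact test helper differs):

    from torchao.prototype.low_bit_optim import AdamW4bit
    from transformers import Trainer, TrainingArguments
    from transformers.training_args import OptimizerNames

    args = TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None")
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    assert optimizer_cls is AdamW4bit
    assert {"lr", "betas", "eps"} <= optimizer_kwargs.keys()  # the default adam kwargs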
