3 files changed: +29 −0 lines changed
@@ -168,6 +168,7 @@
     is_torch_npu_available,
     is_torch_xla_available,
     is_torch_xpu_available,
+    is_torchao_available,
     logging,
     strtobool,
 )
@@ -1451,7 +1452,23 @@ def optimizer_hook(param):
14511452 "gradient_clipping" : float (optim_args .get ("gradient_clipping" , 1.0 )),
14521453 }
14531454 )
1455+ elif args .optim == OptimizerNames .ADAMW_TORCH_4BIT :
1456+ if not is_torchao_available () or version .parse (importlib .metadata .version ("torchao" )) < version .parse (
1457+ "0.4.0"
1458+ ):
1459+ raise ImportError (
1460+ "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
1461+ "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
1462+ )
1463+ if version .parse (importlib .metadata .version ("torch" )) < version .parse ("2.3" ):
1464+ raise ImportError (
1465+ "You need to have `torch>=2.3` in order to use torch 4-bit optimizers. "
1466+ "Install it with `pip install --upgrade torch`"
1467+ )
1468+ from torchao .prototype .low_bit_optim import AdamW4bit
14541469
1470+ optimizer_cls = AdamW4bit
1471+ optimizer_kwargs .update (adam_kwargs )
14551472 else :
14561473 raise ValueError (f"Trainer cannot instantiate unsupported optimizer: { args .optim } " )
14571474 return optimizer_cls , optimizer_kwargs
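
A minimal sketch of how this new branch is reached from user code, assuming `torchao>=0.4.0` and `torch>=2.3` are installed; `my_model` and `train_ds` are placeholders, not part of this diff:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    optim="adamw_torch_4bit",  # string value of the new OptimizerNames.ADAMW_TORCH_4BIT
)
trainer = Trainer(model=my_model, args=args, train_dataset=train_ds)  # placeholders
trainer.train()  # the branch above resolves AdamW4bit and instantiates it with adam_kwargs
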
@@ -154,6 +154,7 @@ class OptimizerNames(ExplicitEnum):
     ADAMW_APEX_FUSED = "adamw_apex_fused"
     ADAFACTOR = "adafactor"
     ADAMW_ANYPRECISION = "adamw_anyprecision"
+    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
     SGD = "sgd"
     ADAGRAD = "adagrad"
     ADAMW_BNB = "adamw_bnb_8bit"
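
The string on the right-hand side is what `TrainingArguments.optim` is matched against; a quick check of that mapping, assuming `OptimizerNames` is importable from `transformers.training_args` as in the test diff below:

from transformers.training_args import OptimizerNames

# Enum lookup by value is how the user-facing string reaches the new member.
assert OptimizerNames("adamw_torch_4bit") is OptimizerNames.ADAMW_TORCH_4BIT
assert OptimizerNames.ADAMW_TORCH_4BIT.value == "adamw_torch_4bit"
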
@@ -99,6 +99,7 @@
     is_apex_available,
     is_bitsandbytes_available,
     is_safetensors_available,
+    is_torchao_available,
     is_torchdistx_available,
 )
 from transformers.utils.hp_naming import TrialShortNamer
@@ -4210,6 +4211,16 @@ def hp_name(trial):
                 dict(default_adam_kwargs, **default_anyprecision_kwargs),
             )
         )
+    if is_torchao_available():
+        import torchao
+
+        optim_test_params.append(
+            (
+                TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
+                torchao.prototype.low_bit_optim.AdamW4bit,
+                default_adam_kwargs,
+            )
+        )
 
 
 @require_torch
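
Each tuple in `optim_test_params` pairs a `TrainingArguments` configuration with the optimizer class and kwargs it should resolve to. A sketch of the kind of check the parametrized test runs over these tuples; `check_optim` is an illustrative helper name, not the suite's actual method:

from transformers import Trainer, TrainingArguments
from transformers.training_args import OptimizerNames


def check_optim(training_args, expected_cls, expected_kwargs):
    # The Trainer maps the optim name to an optimizer class and default kwargs.
    actual_cls, actual_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    assert actual_cls is expected_cls
    for key, value in expected_kwargs.items():
        assert actual_kwargs[key] == value


# Mirroring the tuple added above (requires torchao>=0.4.0 and torch>=2.3):
from torchao.prototype.low_bit_optim import AdamW4bit

args = TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None")
check_optim(args, AdamW4bit, {"lr": args.learning_rate})
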