Commit 25278e8

add qgalore
1 parent 3dca6a2 commit 25278e8

4 files changed: +135 -0 lines changed

src/transformers/trainer.py

Lines changed: 127 additions & 0 deletions
@@ -155,6 +155,7 @@
     is_ipex_available,
     is_lomo_available,
     is_peft_available,
+    is_q_galore_torch_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -1288,6 +1289,132 @@ def get_optimizer_cls_and_kwargs(
             optimizer_cls = torch.optim.Adagrad
         elif args.optim == OptimizerNames.RMSPROP:
             optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [OptimizerNames.QGALORE_ADAMW_8BIT, OptimizerNames.QGALORE_ADAMW_8BIT_LAYERWISE]:
+            if not is_q_galore_torch_available():
+                raise ImportError(
+                    "You need to install `q-galore-torch` in order to use Q-GaLore optimizers."
+                    " Install it with `pip install q-galore-torch`."
+                )
+            from q_galore_torch import QGaLoreAdamW8bit
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                # TODO: check if this is True
+                raise NotImplementedError("Layer-wise Q-GaLore does not support DDP at this time")
+
+            optimizer_cls = QGaLoreAdamW8bit
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define `optim_target_modules` in order to properly use Q-GaLore optimizers"
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, a specific module name, or 'all-linear'; you passed {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize a Q-GaLore optimizer.")
+
+            logger.warning(
+                "Activated Q-GaLore fine-tuning; depending on your model size and hardware, the training might take a while before starting. Please be patient!"
+            )
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            galore_params = []
+            galore_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as Q-GaLore only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                galore_params.append(module.weight)
+                galore_params_names.append(module_name + ".weight")
+
+            if len(galore_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `optim_target_modules`."
+                )
+
+            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
+
+            # The default args are from the official repository: https://github.com/VITA-Group/Q-GaLore
+            galore_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 256)),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 0.25)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+                "quant": optim_args.pop("quant", True),
+                "quant_n_bit": optim_args.pop("quant_n_bit", 4),
+                "quant_group_size": optim_args.pop("quant_group_size", 256),
+                "cos_threshold": optim_args.pop("cos_threshold", 0.4),
+                "gamma_proj": optim_args.pop("gamma_proj", 2),
+                "queue_size": optim_args.pop("queue_size", 5),
+            }
+
+            param_groups = [
+                {"params": non_galore_params},
+                {"params": galore_params, **galore_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("Layer-wise Q-GaLore optimizers do not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_galore_params:
+                    if param.requires_grad:
+                        param_groups = [{"params": [param]}]
+                        optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                # TODO: in the original repo, they multiply the update_proj_gap param by 2, to check
+                for param in galore_params:
+                    param_groups = [{"params": [param], **galore_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if (not hasattr(param, "float_grad")) and param.grad is None:
+                        return
+                    optimizer_dict[param].step()
+                    optimizer_dict[param].zero_grad()
+
+                id_galore_params = [id(p) for p in galore_params]
+
+                # TODO: strange, we are not applying this to every param here compared to GaLore
+                for param in model.parameters():
+                    if id(param) in id_galore_params or param.requires_grad:
+                        setattr(param, "backward_hook", optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
         elif args.optim in [
             OptimizerNames.GALORE_ADAMW,
             OptimizerNames.GALORE_ADAMW_8BIT,
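
For context on the layer-wise branch above: the idea described in the comment is to give every parameter its own small optimizer and step it as soon as that parameter's gradient has been accumulated, so the Trainer itself only runs a no-op LayerWiseDummyOptimizer. The commit stores the hook on a `backward_hook` attribute (presumably consumed inside `q_galore_torch`); the sketch below only illustrates the same per-parameter stepping pattern with plain PyTorch's `register_post_accumulate_grad_hook` and `SGD` as a stand-in optimizer, so it is a generic illustration of the trick, not this commit's exact wiring.

# Generic sketch of per-parameter "layer-wise" stepping via post-accumulation hooks.
# Assumptions: PyTorch >= 2.1 (register_post_accumulate_grad_hook); SGD stands in for QGaLoreAdamW8bit.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# One small optimizer per parameter, so only that parameter's state is touched at step time.
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters() if p.requires_grad}

def optimizer_hook(param):
    # Fires once the gradient for `param` is fully accumulated; step and clear immediately.
    if param.grad is None:
        return
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)

# Training then only calls backward(); the hooks do the stepping, which is why the Trainer
# pairs this with a no-op dummy optimizer and forbids gradient accumulation.
x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()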

src/transformers/training_args.py

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@ class OptimizerNames(ExplicitEnum):
     GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
     GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    QGALORE_ADAMW_8BIT = "qgalore_adamw_8bit"
+    QGALORE_ADAMW_8BIT_LAYERWISE = "qgalore_adamw_8bit_layerwise"
     LOMO = "lomo"
     ADALOMO = "adalomo"
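
These two enum values are what a user passes as the `optim` training argument. A minimal usage sketch, assuming `q-galore-torch` is installed and this commit is in the installed `transformers`; the optimizer name, `optim_target_modules`, and `optim_args` keys come from the diff above, while the checkpoint, output directory, and dataset are illustrative placeholders.

# Hypothetical usage sketch (not part of this commit): enabling Q-GaLore via the new optim names.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any model whose targets are nn.Linear

args = TrainingArguments(
    output_dir="qgalore-test",
    optim="qgalore_adamw_8bit",                             # or "qgalore_adamw_8bit_layerwise"
    optim_target_modules=[r".*q_proj.*", r".*v_proj.*"],    # regexes over module names, or "all-linear"
    optim_args="rank=256,update_proj_gap=200,scale=0.25",   # keys parsed in trainer.py above
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,                          # the layer-wise variant requires 1
)

# `my_dataset` is a stand-in for your own tokenized dataset.
trainer = Trainer(model=model, args=args, train_dataset=my_dataset)
trainer.train()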

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -164,6 +164,7 @@
     is_pytesseract_available,
     is_pytest_available,
     is_pytorch_quantization_available,
+    is_q_galore_torch_available,
    is_quanto_available,
    is_rjieba_available,
    is_sacremoses_available,

src/transformers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
@@ -99,6 +99,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _galore_torch_available = _is_package_available("galore_torch")
+_q_galore_torch_available = _is_package_available("q_galore_torch")
 _lomo_available = _is_package_available("lomo_optim")
 _torchao_available = _is_package_available("torchao")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@@ -346,6 +347,10 @@ def is_galore_torch_available():
     return _galore_torch_available


+def is_q_galore_torch_available():
+    return _q_galore_torch_available
+
+
 def is_lomo_available():
     return _lomo_available
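
The new helper follows the same pattern as `is_galore_torch_available` above: a module-level flag computed once by `_is_package_available`, exposed through a small function so callers can guard the optional import. A minimal sketch of that guarded-import pattern, using only what this diff adds; the fallback branch is illustrative.

# Minimal sketch of the guarded-import pattern this helper enables.
from transformers.utils import is_q_galore_torch_available

if is_q_galore_torch_available():
    from q_galore_torch import QGaLoreAdamW8bit  # the 8-bit Q-GaLore AdamW used in trainer.py
else:
    QGaLoreAdamW8bit = None  # callers should raise a clear ImportError before trying to use it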
